Spaces:
Paused
Paused
from PIL import Image | |
import numpy as np | |
import torch | |
from torchvision import transforms | |
from llava.mm_utils import process_images, tokenizer_image_token, get_model_name_from_path, KeywordsStoppingCriteria | |
from rembg import remove | |
def create_binary_mask(image, threshold=1):
    """Return a binary (mode "1") foreground mask for a PIL image.

    If the image has an alpha band, pixels whose alpha exceeds ``threshold``
    are foreground. Otherwise the image is converted to grayscale and pixels
    brighter than ``threshold`` are foreground.

    Args:
        image: PIL image. In this file it is the output of rembg's ``remove``
            (presumably an RGBA cut-out with a transparent background; confirm
            for the rembg version in use).
        threshold: Values strictly greater than this become foreground (255);
            all others become background (0). Defaults to 1, the original
            cutoff.

    Returns:
        A mode "1" PIL image with the same size as ``image``.
    """
    if "A" in image.getbands():
        # The alpha band marks the segmented subject directly. A luminance
        # test cannot tell (near-)black subject pixels from a black, cleared
        # background, so those pixels would be dropped from the mask.
        channel = image.getchannel("A")
    else:
        channel = image.convert("L")
    # Pillow allows an "L" -> "1" point() mapping; the LUT emits 255/0.
    return channel.point(lambda value: 255 if value > threshold else 0, "1")
def Dataset_evaluate_MoMA(image_pil, prompt, subject, moMA_main_modal):
    """Build a single-sample MoMA evaluation batch from a PIL image.

    Steps:
    - Segment the subject with rembg and binarise the cut-out into a mask.
    - Resize the image and the mask to 512x512.
    - Re-render the image on a white background for the LLaVA image
      processor.

    Args:
        image_pil: Input PIL image of the subject. Any mode is accepted; it is
            converted to RGB first.
        prompt: Text prompt, stored as ``res['text'][0]``.
        subject: Subject label, stored as ``res['label'][0]``.
        moMA_main_modal: MoMA model wrapper. Only its ``image_processor_llava``
            and ``model_llava.config`` attributes are read here.

    Returns:
        dict with keys, in this order:
            'image': float tensor of shape (1, 3, 512, 512), scaled to [-1, 1].
            'mask': float tensor of shape (1, 1, 512, 512) holding 0/1 values.
            'text': ``[prompt]``.
            'llava_processed': result of ``process_images`` on the
                white-background image.
            'label': ``[subject]``.
    """
    llava_processor = moMA_main_modal.image_processor_llava
    llava_config = moMA_main_modal.model_llava.config
    resize = transforms.Resize((512, 512))

    # Force exactly 3 channels. Grayscale/palette inputs would otherwise break
    # the HWC -> CHW permute below, and RGBA would leak a 4th channel into the
    # batch and the LLaVA input.
    image_pil = image_pil.convert("RGB")

    # Segment the subject and turn the cut-out into a binary mask.
    mask_pil = create_binary_mask(remove(image_pil))

    image_pil = resize(image_pil)
    mask_pil = resize(mask_pil)

    mask_np = np.array(mask_pil)
    if mask_np.ndim == 3:
        # Multi-channel mask: keep only the first channel.
        mask_np = mask_np[:, :, 0]

    image = torch.from_numpy(np.array(image_pil)).permute(2, 0, 1)  # uint8, (3, H, W)
    mask = (torch.from_numpy(mask_np).unsqueeze(0).float() > 0).float()  # (1, H, W), values 0/1

    res = {
        'image': (image / 127.5 - 1).unsqueeze(0),
        'mask': mask.unsqueeze(0),
        'text': [prompt],
    }

    # Keep the subject and paint everything outside the mask white (255).
    image_white_bg = image * mask + (1 - mask) * 255
    white_bg_pil = Image.fromarray(
        image_white_bg.permute(1, 2, 0).numpy().astype(np.uint8)
    )
    res['llava_processed'] = process_images([white_bg_pil], llava_processor, llava_config)
    res['label'] = [subject]
    return res