# Source: mq / MagicQuill / llava_new.py
# Author: LIU, Zichen
# Commit: "initial commit" (d4733f5), 5.34 kB
import torch
from transformers import TextStreamer
import webcolors
import os
import random
from collections import Counter
import numpy as np
from torchvision import transforms
from .magic_utils import get_colored_contour, find_different_colors, get_bounding_box_from_mask
from .LLaVA.llava.conversation import conv_templates, SeparatorStyle
from .LLaVA.llava.model.builder import load_pretrained_model
from .LLaVA.llava.mm_utils import get_model_name_from_path, expand2square, tokenizer_image_token
from .LLaVA.llava.constants import (
IMAGE_TOKEN_INDEX,
DEFAULT_IMAGE_TOKEN,
DEFAULT_IM_START_TOKEN,
DEFAULT_IM_END_TOKEN,
IMAGE_PLACEHOLDER,
)
import re
class LLaVAModel:
    """Wrap a fine-tuned LLaVA checkpoint that guesses what the user drew.

    The model is queried in two situations (see ``process``): to guess what a
    free-hand sketch is meant to be, and to name the object enclosed by a
    coloured contour.
    """

    def __init__(self):
        """Load tokenizer, model and image processor with 4-bit weights.

        The checkpoint is expected at ``../models/llava-v1.5-7b-finetune-clean``
        relative to the directory that contains this file.
        """
        current_dir = os.path.dirname(os.path.abspath(__file__))
        model_path = os.path.join(current_dir, "../models/llava-v1.5-7b-finetune-clean")
        (
            self.tokenizer,
            self.model,
            self.image_processor,
            self.context_len,
        ) = load_pretrained_model(
            model_path=model_path,
            model_base=None,
            model_name=get_model_name_from_path(model_path),
            load_4bit=True,
        )

    def generate_description(self, images, question):
        """Ask the model ``question`` about ``images`` and return its answer.

        Args:
            images: Iterable of image tensors. Each tensor is permuted with
                ``(2, 0, 1)`` before being converted to a PIL image, so a
                channels-last ``(H, W, C)`` layout is assumed -- TODO confirm
                against callers.
            question: Prompt text. When it contains ``IMAGE_PLACEHOLDER`` the
                placeholder is replaced by the image token; otherwise the
                image token is prepended on its own line.

        Returns:
            The generated text with special tokens removed and surrounding
            whitespace stripped.
        """
        # The image token is optionally wrapped in start/end markers,
        # depending on how the checkpoint was configured.
        if self.model.config.mm_use_im_start_end:
            image_token = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN
        else:
            image_token = DEFAULT_IMAGE_TOKEN

        if IMAGE_PLACEHOLDER in question:
            qs = re.sub(IMAGE_PLACEHOLDER, image_token, question)
        else:
            qs = image_token + "\n" + question

        # Padding colour used when squaring the image; it does not depend on
        # the individual image, so compute it once.
        # NOTE(review): int() truncates each mean, so fractional means (e.g.
        # values in [0, 1]) all become 0 -- confirm this colour is intended.
        pad_color = tuple(int(x) for x in self.image_processor.image_mean)

        images_tensor = []
        image_sizes = []
        to_pil = transforms.ToPILImage()
        for image in images:
            # NOTE(review): process() passes tensors already multiplied by
            # 255; check how ToPILImage scales floating-point input and
            # confirm the intended value range.
            pil_image = to_pil(image.clone().permute(2, 0, 1).cpu())
            image_sizes.append(pil_image.size)
            pil_image = expand2square(pil_image, pad_color)
            pixel_values = self.image_processor.preprocess(
                pil_image, return_tensors='pt'
            )['pixel_values'][0]
            images_tensor.append(pixel_values.half())

        conv = conv_templates["llava_v1"].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .cuda()
        )

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=images_tensor,
                image_sizes=image_sizes,
                temperature=0.2,
                do_sample=True,
                use_cache=True,
            )

        # Let the tokenizer drop BOS/EOS markers instead of slicing the text
        # between '>' and '<' by hand: the manual split raised IndexError
        # when no marker was present, cut answers containing '<' or '>', and
        # kept the leading space that follows the BOS marker.
        return self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()

    def process(self, image, colored_image, add_mask):
        """Guess the content of the user's sketch and of the coloured region.

        Args:
            image: Original image tensor. It is compared against a brightness
                threshold of 0.8 and painted with 0.0 / 1.0, so values in
                [0, 1] are assumed -- TODO confirm. ``squeeze()`` is applied
                before it is handed to the model.
            colored_image: The image after the user's colour edits; only
                used when it is not equal to ``image``.
            add_mask: Mask of the added sketch strokes. The sketch question
                is asked when its sum is positive; pixels with a value above
                0.5 are treated as strokes.

        Returns:
            A tuple ``(description, answer1, answer2)``. ``description`` is
            always the empty string. ``answer1`` is the model's guess for the
            sketch, or ``""`` when the mask is empty. ``answer2`` is
            ``"<color>, <guess>"`` for the coloured region, or ``""`` when
            ``colored_image`` equals ``image``.
        """
        description = ""
        answer1 = ""
        answer2 = ""

        if torch.sum(add_mask).item() > 0:
            x_min, y_min, x_max, y_max = get_bounding_box_from_mask(add_mask)
            question = f"This is an 'I draw, you guess' game. I will upload an image containing some sketches. To help you locate the sketch, I will give you the normalized bounding box coordinates of the sketch where their original coordinates are divided by the image width and height. The top-left corner of the bounding box is at ({x_min}, {y_min}), and the bottom-right corner is at ({x_max}, {y_max}). Now tell me, what am I trying to draw with these sketches in the image?"

            # Draw the strokes onto a copy so the caller's tensor is untouched.
            image_with_sketch = image.clone()
            bool_add_mask = add_mask > 0.5
            # Pick a stroke colour that contrasts with what lies underneath:
            # black on bright regions, white otherwise.
            mean_brightness = image_with_sketch[bool_add_mask].mean()
            if mean_brightness > 0.8:
                image_with_sketch[bool_add_mask] = 0.0
            else:
                image_with_sketch[bool_add_mask] = 1.0

            answer1 = self.generate_description([image_with_sketch.squeeze() * 255], question)
            print(answer1)

        if not torch.equal(image, colored_image):
            color = find_different_colors(image.squeeze() * 255, colored_image.squeeze() * 255)
            image_with_bbox, colored_mask = get_colored_contour(colored_image.squeeze() * 255, image.squeeze() * 255)
            x_min, y_min, x_max, y_max = get_bounding_box_from_mask(colored_mask)
            question = f"The user will upload an image containing some contours in red color. To help you locate the contour, I will give you the normalized bounding box coordinates where their original coordinates are divided by the image width and height. The top-left corner of the bounding box is at ({x_min}, {y_min}), and the bottom-right corner is at ({x_max}, {y_max}). You need to identify what is inside the contours using a single word or phrase."
            answer2 = color + ', ' + self.generate_description([image_with_bbox.squeeze() * 255], question)
            print(answer2)

        return (description, answer1, answer2)