import base64
from io import BytesIO
from typing import Any, Dict

import torch
from PIL import Image
from transformers import TextStreamer

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

from configs import *
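
# NOTE: the wildcard import above is expected to supply the deployment constants
# referenced throughout this handler: MODEL_PATH, MODEL_BASE, LOAD_8BIT, LOAD_4BIT,
# IMAGE_ASPECT_RATIO, TEMPERATURE and MAX_NEW_TOKENS. The names are taken from
# their usage below; the concrete values live in configs.py.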


class EndpointHandler:
    def __init__(self, path: str = MODEL_PATH):
        disable_torch_init()
        self.model_path = path
        self.model_base = MODEL_BASE
        self.model_name = get_model_name_from_path(self.model_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=self.model_path,
            model_base=self.model_base,
            model_name=self.model_name,
            load_8bit=LOAD_8BIT,
            load_4bit=LOAD_4BIT,
            device=self.device,
        )

        # Pick the conversation template that matches the checkpoint name.
        if "llama-2" in self.model_name.lower():
            self.conv_mode = "llava_llama_2"
        elif "v1" in self.model_name.lower():
            self.conv_mode = "llava_v1"
        elif "mpt" in self.model_name.lower():
            self.conv_mode = "mpt"
        else:
            self.conv_mode = "llava_v0"
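
    # Expected request shape (inferred from how the payload is unpacked below):
    #   {"inputs": "<base64-encoded image>", "text": "<user prompt>"}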
    def __call__(self, data: Dict[str, Any]) -> str:
        self.conv = conv_templates[self.conv_mode].copy()
        if "mpt" in self.model_name.lower():
            self.roles = ("user", "assistant")
        else:
            self.roles = self.conv.roles

        # Unpack the request: the base64 image travels under "inputs",
        # the prompt under "text".
        image_encoded = data.pop("inputs", data)
        text = data["text"]

        image = self.decode_base64_image(image_encoded)
        if image.mode != "RGB":
            image = image.convert("RGB")

        model_config = {"image_aspect_ratio": IMAGE_ASPECT_RATIO}
        image_tensor = process_images([image], self.image_processor, model_config)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)

        # Build the prompt: the image placeholder token(s) are prepended to the
        # user text, then wrapped in the selected conversation template.
        if self.model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + text
        else:
            inp = DEFAULT_IMAGE_TOKEN + '\n' + text
        self.conv.append_message(self.conv.roles[0], inp)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.model.device)
        )
        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=TEMPERATURE,
                max_new_tokens=MAX_NEW_TOKENS,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )

        # Decode only the newly generated tokens, dropping the echoed prompt.
        outputs = self.tokenizer.decode(
            output_ids[0, input_ids.shape[1]:], skip_special_tokens=True
        ).strip()
        return outputs

    def decode_base64_image(self, image_string: str) -> Image.Image:
        # Decode a base64-encoded image string into a PIL image.
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
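

# Minimal local smoke test (a sketch, not part of the deployed handler).
# It assumes an image file "example.jpg" next to this script and that the
# constants imported from configs point at a reachable LLaVA checkpoint.
if __name__ == "__main__":
    with open("example.jpg", "rb") as f:
        encoded_image = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    answer = handler({"inputs": encoded_image, "text": "What is shown in this image?"})
    print(answer)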