import base64
from io import BytesIO
from types import SimpleNamespace
from typing import Any, Dict

import torch
from PIL import Image
from transformers import TextStreamer

from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    process_images,
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)

# Provides MODEL_PATH, MODEL_BASE, LOAD_8BIT, LOAD_4BIT, IMAGE_ASPECT_RATIO,
# TEMPERATURE, and MAX_NEW_TOKENS.
from configs import *


class EndpointHandler:
    def __init__(self, path=MODEL_PATH):
        disable_torch_init()

        self.model_path = path
        self.model_base = MODEL_BASE
        self.model_name = get_model_name_from_path(self.model_path)
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer, self.model, self.image_processor, self.context_len = load_pretrained_model(
            model_path=self.model_path,
            model_base=self.model_base,
            model_name=self.model_name,
            load_8bit=LOAD_8BIT,
            load_4bit=LOAD_4BIT,
            device=self.device,
        )

        # Pick the conversation template that matches the checkpoint family.
        if "llama-2" in self.model_name.lower():
            self.conv_mode = "llava_llama_2"
        elif "v1" in self.model_name.lower():
            self.conv_mode = "llava_v1"
        elif "mpt" in self.model_name.lower():
            self.conv_mode = "mpt"
        else:
            self.conv_mode = "llava_v0"

    def __call__(self, data: Dict[str, Any]) -> str:
        # Start a fresh conversation per request so history never leaks
        # between calls.
        self.conv = conv_templates[self.conv_mode].copy()
        if "mpt" in self.model_name.lower():
            self.roles = ("user", "assistant")
        else:
            self.roles = self.conv.roles

        # The request payload carries a base64-encoded image and a prompt.
        image_encoded = data.pop("inputs", data)
        text = data["text"]

        # Decode the base64 payload and normalize to RGB.
        image = self.decode_base64_image(image_encoded)
        if image.mode != "RGB":
            image = image.convert("RGB")

        # process_images reads image_aspect_ratio as an attribute, not a dict
        # key, so wrap the setting in a namespace object.
        model_config = SimpleNamespace(image_aspect_ratio=IMAGE_ASPECT_RATIO)
        image_tensor = process_images([image], self.image_processor, model_config)
        image_tensor = image_tensor.to(self.model.device, dtype=torch.float16)

        # Prepend the image token(s) to the user prompt.
        inp = text
        if self.model.config.mm_use_im_start_end:
            inp = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + "\n" + inp
        else:
            inp = DEFAULT_IMAGE_TOKEN + "\n" + inp

        self.conv.append_message(self.conv.roles[0], inp)
        self.conv.append_message(self.conv.roles[1], None)
        prompt = self.conv.get_prompt()

        input_ids = (
            tokenizer_image_token(prompt, self.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
            .unsqueeze(0)
            .to(self.device)
        )

        # Stop generation once the template's separator string appears.
        stop_str = self.conv.sep if self.conv.sep_style != SeparatorStyle.TWO else self.conv.sep2
        keywords = [stop_str]
        stopping_criteria = KeywordsStoppingCriteria(keywords, self.tokenizer, input_ids)
        streamer = TextStreamer(self.tokenizer, skip_prompt=True, skip_special_tokens=True)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=image_tensor,
                do_sample=True,
                temperature=TEMPERATURE,
                max_new_tokens=MAX_NEW_TOKENS,
                streamer=streamer,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )

        # Drop the prompt tokens and decode only the generated continuation.
        outputs = self.tokenizer.decode(output_ids[0, input_ids.shape[1]:]).strip()
        # The stopping criterion leaves the separator in the text; trim it.
        if outputs.endswith(stop_str):
            outputs = outputs[: -len(stop_str)].strip()
        return outputs

    def decode_base64_image(self, image_string):
        base64_image = base64.b64decode(image_string)
        buffer = BytesIO(base64_image)
        image = Image.open(buffer)
        return image
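
# --- Usage sketch ------------------------------------------------------------
# A minimal local smoke test, not part of the endpoint contract. The image
# path and prompt below are placeholders; the payload shape matches what
# __call__ expects: {"inputs": <base64-encoded image>, "text": <prompt>}.
if __name__ == "__main__":
    with open("test.jpg", "rb") as f:  # placeholder sample image
        image_b64 = base64.b64encode(f.read()).decode("utf-8")

    handler = EndpointHandler()
    print(handler({"inputs": image_b64, "text": "Describe this image."}))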