import argparse
import ast
import base64
import dataclasses
import logging
import math
import shutil
import time
from enum import Enum, auto
from io import BytesIO
from typing import List

import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer, StoppingCriteria

# Model Constants
IGNORE_INDEX = -100
IMAGE_TOKEN_INDEX = -200
DEFAULT_IMAGE_TOKEN = "<image>"
DEFAULT_IMAGE_PATCH_TOKEN = "<im_patch>"
DEFAULT_IM_START_TOKEN = "<im_start>"
DEFAULT_IM_END_TOKEN = "<im_end>"
IMAGE_PLACEHOLDER = "<image-placeholder>"


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()
    TWO = auto()
    MPT = auto()
    PLAIN = auto()
    LLAMA_2 = auto()
    TINY_LLAMA = auto()
    QWEN_2 = auto()


@dataclasses.dataclass
class Conversation:
    """A class that keeps all conversation history."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "###"
    sep2: str = None
    version: str = "Unknown"

    skip_next: bool = False

    def get_prompt(self):
        messages = self.messages
        if len(messages) > 0 and type(messages[0][1]) is tuple:
            messages = self.messages.copy()
            init_role, init_msg = messages[0].copy()
            init_msg = init_msg[0].replace(DEFAULT_IMAGE_TOKEN, "").strip()
            if 'mmtag' in self.version:
                messages[0] = (init_role, init_msg)
                messages.insert(0, (self.roles[0], "<Image><image></Image>"))
                messages.insert(1, (self.roles[1], "Received."))
            else:
                messages[0] = (init_role, DEFAULT_IMAGE_TOKEN + "\n" + init_msg)

        if self.sep_style == SeparatorStyle.SINGLE:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + self.sep
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.TWO:
            seps = [self.sep, self.sep2]
            ret = self.system + seps[0]
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + ": " + message + seps[i % 2]
                else:
                    ret += role + ":"
        elif self.sep_style == SeparatorStyle.MPT:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.LLAMA_2:
            wrap_sys = lambda msg: f"<<SYS>>\n{msg}\n<</SYS>>\n\n" if len(msg) > 0 else msg
            wrap_inst = lambda msg: f"[INST] {msg} [/INST]"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i == 0:
                        message = wrap_sys(self.system) + message
                    if i % 2 == 0:
                        message = wrap_inst(message)
                        ret += self.sep + message
                    else:
                        ret += " " + message + " " + self.sep2
                else:
                    ret += ""
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.TINY_LLAMA:
            wrap_sys = lambda msg: f"<|system|>\n{msg}\n"
            wrap_user = lambda msg: f"<|user|>\n{msg}\n"
            wrap_assistant = lambda msg: f"<|assistant|>\n{msg}"
            ret = ""

            for i, (role, message) in enumerate(messages):
                if i == 0:
                    assert message, "first message should not be none"
                    assert role == self.roles[0], "first message should come from user"
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    if i % 2 == 0:
                        message = wrap_user(message)
                        if i == 0:
                            message = wrap_sys(self.system) + message
                        ret += self.sep + message
                    else:
                        message = wrap_assistant(message) + self.sep2
                        ret += message
                else:
                    ret += "<|assistant|>\n"
            ret = ret.lstrip(self.sep)
        elif self.sep_style == SeparatorStyle.QWEN_2:
            ret = self.system + self.sep
            for role, message in messages:
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += role + message + self.sep
                else:
                    ret += role
        elif self.sep_style == SeparatorStyle.PLAIN:
            seps = [self.sep, self.sep2]
            ret = self.system
            for i, (role, message) in enumerate(messages):
                if message:
                    if type(message) is tuple:
                        message, _, _ = message
                    ret += message + seps[i % 2]
                else:
                    ret += ""
        else:
            raise ValueError(f"Invalid style: {self.sep_style}")

        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        images = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    if image_process_mode == "Pad":
                        image = expand2square(image, (122, 116, 104))
                    elif image_process_mode in ["Default", "Crop"]:
                        pass
                    elif image_process_mode == "Resize":
                        image = image.resize((336, 336))
                    else:
                        raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
                    # Downscale so the shortest edge is at most 400 px and the
                    # longest edge at most 800 px, keeping the aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if longest_edge != max(image.size):
                        if H > W:
                            H, W = longest_edge, shortest_edge
                        else:
                            H, W = shortest_edge, longest_edge
                        image = image.resize((W, H))
                    if return_pil:
                        images.append(image)
                    else:
                        buffered = BytesIO()
                        image.save(buffered, format="PNG")
                        img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                        images.append(img_b64_str)
        return images

    def to_gradio_chatbot(self):
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    msg, image, image_process_mode = msg
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    img_str = f'<img src="data:image/jpeg;base64,{img_b64_str}" alt="user upload image" />'
                    msg = img_str + msg.replace(DEFAULT_IMAGE_TOKEN, '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version)

    def dict(self):
        if len(self.get_images()) > 0:
            return {
                "system": self.system,
                "roles": self.roles,
                "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
                "offset": self.offset,
                "sep": self.sep,
                "sep2": self.sep2,
            }
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
        }


conv_phi_v0 = Conversation(
    system="A chat between a curious user and an artificial intelligence assistant. "
           "The assistant gives helpful, detailed, and polite answers to the user's questions.",
    roles=("USER", "ASSISTANT"),
    version="phi",
    messages=(),
    offset=0,
    sep_style=SeparatorStyle.TWO,
    sep=" ",
    sep2="<|endoftext|>",
)


def select_best_resolution(original_size, possible_resolutions):
    """
    Selects the best resolution from a list of possible resolutions based on the original size.

    Args:
        original_size (tuple): The original size of the image in the format (width, height).
        possible_resolutions (list): A list of possible resolutions in the format [(width1, height1), (width2, height2), ...].

    Returns:
        tuple: The best fit resolution in the format (width, height).
    """
    original_width, original_height = original_size
    best_fit = None
    max_effective_resolution = 0
    min_wasted_resolution = float('inf')

    for width, height in possible_resolutions:
        scale = min(width / original_width, height / original_height)
        downscaled_width, downscaled_height = int(original_width * scale), int(original_height * scale)
        effective_resolution = min(downscaled_width * downscaled_height, original_width * original_height)
        wasted_resolution = (width * height) - effective_resolution

        if effective_resolution > max_effective_resolution or (
                effective_resolution == max_effective_resolution and wasted_resolution < min_wasted_resolution):
            max_effective_resolution = effective_resolution
            min_wasted_resolution = wasted_resolution
            best_fit = (width, height)

    return best_fit


## added by llava-1.6
def resize_and_pad_image(image, target_resolution):
    """
    Resize and pad an image to a target resolution while maintaining aspect ratio.

    Args:
        image (PIL.Image.Image): The input image.
        target_resolution (tuple): The target resolution (width, height) of the image.

    Returns:
        PIL.Image.Image: The resized and padded image.
    """
    original_width, original_height = image.size
    target_width, target_height = target_resolution

    scale_w = target_width / original_width
    scale_h = target_height / original_height

    if scale_w < scale_h:
        new_width = target_width
        new_height = min(math.ceil(original_height * scale_w), target_height)
    else:
        new_height = target_height
        new_width = min(math.ceil(original_width * scale_h), target_width)

    # Resize the image
    resized_image = image.resize((new_width, new_height))

    # Center the resized image on a black canvas of the target size
    new_image = Image.new('RGB', (target_width, target_height), (0, 0, 0))
    paste_x = (target_width - new_width) // 2
    paste_y = (target_height - new_height) // 2
    new_image.paste(resized_image, (paste_x, paste_y))

    return new_image


## added by llava-1.6
def divide_to_patches(image, patch_size):
    """
    Divides an image into patches of a specified size.

    Args:
        image (PIL.Image.Image): The input image.
        patch_size (int): The size of each patch.

    Returns:
        list: A list of PIL.Image.Image objects representing the patches.
    """
    patches = []
    width, height = image.size
    for i in range(0, height, patch_size):
        for j in range(0, width, patch_size):
            box = (j, i, j + patch_size, i + patch_size)
            patch = image.crop(box)
            patches.append(patch)

    return patches


## added by llava-1.6
def get_anyres_image_grid_shape(image_size, grid_pinpoints, patch_size):
    """
    Calculate the shape of the image patch grid after the preprocessing for images of any resolution.

    Args:
        image_size (tuple): The size of the input image in the format (width, height).
        grid_pinpoints (str or list): A list of possible resolutions, or its string representation.
        patch_size (int): The size of each image patch.

    Returns:
        tuple: The shape of the image patch grid in the format (width, height).
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    width, height = select_best_resolution(image_size, possible_resolutions)
    return width // patch_size, height // patch_size


## added by llava-1.6
def process_anyres_image(image, processor, grid_pinpoints):
    """
    Process an image with variable resolutions.

    Args:
        image (PIL.Image.Image): The input image to be processed.
        processor: The image processor object.
        grid_pinpoints (str or list): A list of possible resolutions, or its string representation.

    Returns:
        torch.Tensor: A tensor containing the processed image patches.
    """
    if type(grid_pinpoints) is list:
        possible_resolutions = grid_pinpoints
    else:
        possible_resolutions = ast.literal_eval(grid_pinpoints)
    best_resolution = select_best_resolution(image.size, possible_resolutions)
    image_padded = resize_and_pad_image(image, best_resolution)

    patches = divide_to_patches(image_padded, processor.crop_size['height'])

    # A downscaled copy of the whole image goes first, followed by the high-resolution patches.
    image_original_resize = image.resize((processor.size['shortest_edge'], processor.size['shortest_edge']))

    image_patches = [image_original_resize] + patches
    image_patches = [processor.preprocess(image_patch, return_tensors='pt')['pixel_values'][0]
                     for image_patch in image_patches]
    return torch.stack(image_patches, dim=0)


def load_image_from_base64(image):
    return Image.open(BytesIO(base64.b64decode(image)))


def expand2square(pil_img, background_color):
    """Pad `pil_img` to a square canvas filled with `background_color`, keeping it centered."""
    width, height = pil_img.size
    if width == height:
        return pil_img
    elif width > height:
        result = Image.new(pil_img.mode, (width, width), background_color)
        result.paste(pil_img, (0, (width - height) // 2))
        return result
    else:
        result = Image.new(pil_img.mode, (height, height), background_color)
        result.paste(pil_img, ((height - width) // 2, 0))
        return result


def process_images(images, image_processor, model_cfg):
    """Preprocess a list of PIL images according to `model_cfg.image_aspect_ratio`."""
    image_aspect_ratio = getattr(model_cfg, "image_aspect_ratio", None)
    new_images = []
    if image_aspect_ratio == 'pad':
        for image in images:
            image = expand2square(image, tuple(int(x * 255) for x in image_processor.image_mean))
            image = image_processor.preprocess(image, return_tensors='pt')['pixel_values'][0]
            new_images.append(image)
    elif image_aspect_ratio == "anyres":
        for image in images:
            image = process_anyres_image(image, image_processor, model_cfg.image_grid_pinpoints)
            new_images.append(image)
    else:
        return image_processor(images, return_tensors='pt')['pixel_values']
    if all(x.shape == new_images[0].shape for x in new_images):
        new_images = torch.stack(new_images, dim=0)
    return new_images


def tokenizer_image_token(prompt, tokenizer, image_token_index=IMAGE_TOKEN_INDEX, return_tensors=None):
    """Tokenize `prompt`, replacing each DEFAULT_IMAGE_TOKEN with `image_token_index`."""
    prompt_chunks = [tokenizer(chunk).input_ids for chunk in prompt.split(DEFAULT_IMAGE_TOKEN)]

    def insert_separator(X, sep):
        return [ele for sublist in zip(X, [sep] * len(X)) for ele in sublist][:-1]

    input_ids = []
    offset = 0
    # Keep a single leading BOS token; the BOS that the tokenizer adds to every chunk is stripped below.
    if len(prompt_chunks) > 0 and len(prompt_chunks[0]) > 0 and prompt_chunks[0][0] == tokenizer.bos_token_id:
        offset = 1
        input_ids.append(prompt_chunks[0][0])

    for x in insert_separator(prompt_chunks, [image_token_index] * (offset + 1)):
        input_ids.extend(x[offset:])

    if return_tensors is not None:
        if return_tensors == 'pt':
            return torch.tensor(input_ids, dtype=torch.long)
        raise ValueError(f'Unsupported tensor type: {return_tensors}')
    return input_ids


def get_model_name_from_path(model_path):
    model_path = model_path.strip("/")
    model_paths = model_path.split("/")
    if model_paths[-1].startswith('checkpoint-'):
        return model_paths[-2] + "_" + model_paths[-1]
    else:
        return model_paths[-1]


class KeywordsStoppingCriteria(StoppingCriteria):
    def __init__(self, keywords, tokenizer, input_ids):
        self.keywords = keywords
        self.keyword_ids = []
        self.max_keyword_len = 0
        for keyword in keywords:
            cur_keyword_ids = tokenizer(keyword).input_ids
            if len(cur_keyword_ids) > 1 and cur_keyword_ids[0] == tokenizer.bos_token_id:
                cur_keyword_ids = cur_keyword_ids[1:]
            if len(cur_keyword_ids) > self.max_keyword_len:
                self.max_keyword_len = len(cur_keyword_ids)
            self.keyword_ids.append(torch.tensor(cur_keyword_ids))
        self.tokenizer = tokenizer
        self.start_len = input_ids.shape[1]

    def call_for_batch(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        offset = min(output_ids.shape[1] - self.start_len, self.max_keyword_len)
        self.keyword_ids = [keyword_id.to(output_ids.device) for keyword_id in self.keyword_ids]
        for keyword_id in self.keyword_ids:
            if (output_ids[0, -keyword_id.shape[0]:] == keyword_id).all():
                return True
        outputs = self.tokenizer.batch_decode(output_ids[:, -offset:], skip_special_tokens=True)[0]
        for keyword in self.keywords:
            if keyword in outputs:
                return True
        return False

    def __call__(self, output_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        outputs = []
        for i in range(output_ids.shape[0]):
            outputs.append(self.call_for_batch(output_ids[i].unsqueeze(0), scores))
        return all(outputs)


def load_image(image_file):
    """Load an image from a local path or an http(s) URL and convert it to RGB."""
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=30)
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")
    else:
        image = Image.open(image_file).convert("RGB")
    return image


def generate(
    prompt: str,
    model,
    tokenizer=None,
    image: str = None,
    device: str = None,
    max_new_tokens: int = 1024,
    num_beams: int = 1,
    top_p: float = None,
    temperature: float = 0.2,
):
    """Run TinyLLaVA-Phi on a text prompt and an optional image.

    `model` may be a checkpoint path or an already loaded model.
    Returns a tuple of (generated text, generation time in seconds).
    """
    if not device:
        if torch.cuda.is_available() and torch.cuda.device_count():
            device = "cuda:0"
            logging.warning(
                'inference device is not set, using cuda:0, %s',
                torch.cuda.get_device_name(0)
            )
        else:
            device = 'cpu'
            logging.warning(
                'No CUDA device detected, using cpu, '
                'expect slower speeds.'
            )

    if 'cuda' in device and not torch.cuda.is_available():
        raise ValueError('CUDA device requested but no CUDA device detected.')

    if isinstance(model, str):
        checkpoint_path = model
        model = AutoModelForCausalLM.from_pretrained(
            checkpoint_path,
            trust_remote_code=True
        )

    config = model.config
    if tokenizer is None:
        # Load the tokenizer from the same checkpoint the model was loaded from.
        tokenizer = AutoTokenizer.from_pretrained(
            config.name_or_path,
            use_fast=False,
            model_max_length=config.tokenizer_model_max_length,
            padding_side=config.tokenizer_padding_side,
        )
    image_processor = model.vision_tower._image_processor
    context_len = getattr(config, 'max_sequence_length', 2048)
    model.to(device).eval()

    if image is not None:
        prompt = DEFAULT_IMAGE_TOKEN + '\n' + prompt
    conv = conv_phi_v0.copy()
    conv.append_message(conv.roles[0], prompt)
    conv.append_message(conv.roles[1], None)
    prompt = conv.get_prompt()

    image_tensor = None
    if image is not None:
        image = load_image(image)
        # process_images expects a list of images; match the model's dtype and device.
        image_tensor = process_images([image], image_processor, config).to(model.device, dtype=model.dtype)

    # Token ids must stay integer (torch.long); only the image tensor is cast to the model dtype.
    input_ids = (
        tokenizer_image_token(prompt, tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt")
        .unsqueeze(0)
        .to(model.device)
    )

    # Generate
    stime = time.time()
    # stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
    # keywords = [stop_str]
    # stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
    with torch.inference_mode():
        output_ids = model.generate(
            input_ids,
            images=image_tensor,
            do_sample=temperature > 0,
            temperature=temperature,
            top_p=top_p,
            num_beams=num_beams,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=max_new_tokens,
            use_cache=True,
            # stopping_criteria=[stopping_criteria],
        )
    generation_time = time.time() - stime

    outputs = tokenizer.batch_decode(
        output_ids, skip_special_tokens=True
    )[0]
    # if outputs.endswith(stop_str):
    #     outputs = outputs[: -len(stop_str)]
    outputs = outputs.strip()

    return outputs, generation_time


def tinyllava_phi_generate_parser():
    """Argument parser for the TinyLLaVA-Phi generate script."""

    class KwargsParser(argparse.Action):
        """Parser action class to parse kwargs of the form key=value."""

        def __call__(self, parser, namespace, values, option_string=None):
            setattr(namespace, self.dest, dict())
            for val in values:
                if '=' not in val:
                    raise ValueError(
                        'Argument parsing error, kwargs are expected in'
                        ' the form of key=value.'
                    )
                # Split on the first '=' only, so values may themselves contain '='.
                kwarg_k, kwarg_v = val.split('=', 1)
                try:
                    converted_v = int(kwarg_v)
                except ValueError:
                    try:
                        converted_v = float(kwarg_v)
                    except ValueError:
                        converted_v = kwarg_v
                getattr(namespace, self.dest)[kwarg_k] = converted_v

    parser = argparse.ArgumentParser('TinyLLaVA-Phi Generate Module')
    parser.add_argument(
        '--model',
        dest='model',
        help='Path to the HF-converted model.',
        required=True,
        type=str,
    )
    parser.add_argument(
        '--prompt',
        dest='prompt',
        help='Prompt for the LLM call.',
        default='',
        type=str,
    )
    parser.add_argument(
        '--device',
        dest='device',
        help='Device used for inference.',
        type=str,
    )
    parser.add_argument("--image", type=str, default=None, help='Optional image path or http(s) URL.')
    parser.add_argument("--temperature", type=float, default=0, help='0 disables sampling (greedy or beam search).')
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    parser.add_argument("--max_new_tokens", type=int, default=512)
    return parser.parse_args()


if __name__ == '__main__':
    args = tinyllava_phi_generate_parser()

    output_text, generation_time = generate(
        prompt=args.prompt,
        image=args.image,
        model=args.model,
        device=args.device,
        max_new_tokens=args.max_new_tokens,
        num_beams=args.num_beams,
        top_p=args.top_p,
        temperature=args.temperature,
    )

    # shutil.get_terminal_size falls back to 80 columns when stdout is not a terminal.
    width = shutil.get_terminal_size().columns
    print_txt = (
        f'\r\n{"=" * width}\r\n'
        '\033[1m Prompt + Generated Output\033[0m\r\n'
        f'{"-" * width}\r\n'
        f'{output_text}\r\n'
        f'{"-" * width}\r\n'
        '\r\nGeneration took'
        f'\033[1m\033[92m {round(generation_time, 2)} \033[0m'
        'seconds.\r\n'
    )
    print(print_txt)
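

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the TinyLLaVA-Phi API: how a user prompt
# becomes the `input_ids` passed to `model.generate`. It re-implements, in a
# simplified and self-contained form, what `conv_phi_v0.get_prompt()` (the
# SeparatorStyle.TWO branch) and `tokenizer_image_token` do above. The
# whitespace tokenizer `_toy_tokenize`, its tiny vocabulary, and BOS id 1 are
# hypothetical stand-ins so the sketch runs without downloading a model; real
# token ids come from the Phi tokenizer.
# ---------------------------------------------------------------------------
from typing import Callable, List

_SKETCH_SYSTEM = (
    "A chat between a curious user and an artificial intelligence assistant. "
    "The assistant gives helpful, detailed, and polite answers to the user's questions."
)


def _sketch_phi_prompt(user_msg: str, with_image: bool) -> str:
    """SeparatorStyle.TWO layout for one user turn plus an empty assistant turn."""
    if with_image:
        user_msg = "<image>\n" + user_msg
    # system + sep (" "), then "USER: msg" + sep (" "), then the open "ASSISTANT:" slot.
    return _SKETCH_SYSTEM + " " + "USER: " + user_msg + " " + "ASSISTANT:"


def _toy_tokenize(text: str, bos_id: int = 1) -> List[int]:
    """Hypothetical tokenizer: prepends BOS and maps known words to fixed ids (unknown -> 0)."""
    vocab = {"describe": 10, "this": 11, "image": 12}
    return [bos_id] + [vocab.get(word, 0) for word in text.split()]


def _sketch_splice_image_token(
    prompt: str,
    tokenize: Callable[[str], List[int]],
    bos_id: int = 1,
    image_token_index: int = -200,
) -> List[int]:
    """Tokenize the text around each "<image>" and put a sentinel id where it was."""
    chunks = [tokenize(chunk) for chunk in prompt.split("<image>")]
    input_ids: List[int] = []
    # Keep one leading BOS; the BOS the tokenizer adds to every chunk is dropped below.
    offset = 1 if chunks and chunks[0] and chunks[0][0] == bos_id else 0
    if offset:
        input_ids.append(bos_id)
    for idx, chunk in enumerate(chunks):
        if idx > 0:
            # The model later replaces this sentinel with the vision features.
            input_ids.append(image_token_index)
        input_ids.extend(chunk[offset:])
    return input_ids


# Example:
#   _sketch_phi_prompt("Hi", with_image=True)
#   -> "A chat between ... user's questions. USER: <image>\nHi ASSISTANT:"
#   _sketch_splice_image_token("<image>\ndescribe this image", _toy_tokenize)
#   -> [1, -200, 10, 11, 12]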