import os
from io import BytesIO

import msgpack
import numpy as np
import torch
from PIL import Image
from transformers import AutoTokenizer

from llava.constants import (
    MM_TOKEN_INDEX,
    DEFAULT_VIDEO_START_TOKEN,
    DEFAULT_VIDEO_END_TOKEN,
    DEFAULT_VIDEO_TOKEN,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates
from llava.utils import disable_torch_init
from llava.mm_utils import (
    tokenizer_image_token,
    KeywordsStoppingCriteria,
    process_images_v2,
)
from llava.model import LlavaMistralForCausalLM
from llava.model.multimodal_encoder.processor import Blip2ImageTrainProcessor


def select_frames(input_frames, num_segments=10):
    """Uniformly sample `num_segments` frames from the input frame sequence."""
    indices = np.linspace(start=0, stop=len(input_frames) - 1, num=num_segments).astype(int)
    return [input_frames[ind] for ind in indices]


def load_model(model_path, device_map):
    """Load the LLaVA-Mistral checkpoint, its tokenizer, and the vision tower."""
    kwargs = {"device_map": device_map, "torch_dtype": torch.float32}
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = LlavaMistralForCausalLM.from_pretrained(
        model_path,
        low_cpu_mem_usage=True,
        **kwargs,
    )

    # Register the image/video delimiter tokens and grow the embedding table to match.
    tokenizer.add_tokens(
        [DEFAULT_IM_START_TOKEN, DEFAULT_IM_END_TOKEN,
         DEFAULT_VIDEO_START_TOKEN, DEFAULT_VIDEO_END_TOKEN],
        special_tokens=True,
    )
    model.resize_token_embeddings(len(tokenizer))

    vision_tower = model.get_vision_tower()
    if not vision_tower.is_loaded:
        vision_tower.load_model(device_map=device_map)

    return model, tokenizer


class EndpointHandler:
    def __init__(self):
        model_path = './checkpoint-3000'
        disable_torch_init()
        model_path = os.path.expanduser(model_path)
        model, tokenizer = load_model(model_path, device_map={"": 0})

        image_processor = Blip2ImageTrainProcessor(
            image_size=model.config.img_size,
            is_training=False,
        )

        self.tokenizer = tokenizer
        self.device = torch.device('cpu')
        self.model = model.to(self.device)
        self.image_processor = image_processor
        self.conv_mode = 'v1'

    def inference_frames(self, images, question, temperature):
        # Cap the number of frames at 10 by uniform sampling.
        if len(images) > 10:
            images = select_frames(images)

        images_tensor = process_images_v2(images, self.image_processor, self.model.config)
        images_tensor = images_tensor.to(self.device)

        # Prepend the multimodal placeholder tokens: image markers for a single frame,
        # video markers for multi-frame input.
        qs = question
        if len(images) == 1:
            qs = DEFAULT_IM_START_TOKEN + DEFAULT_IMAGE_TOKEN + DEFAULT_IM_END_TOKEN + '\n' + qs
        else:
            qs = DEFAULT_VIDEO_START_TOKEN + DEFAULT_VIDEO_TOKEN + DEFAULT_VIDEO_END_TOKEN + '\n' + qs

        # Build the conversation prompt from the selected template.
        conv = conv_templates[self.conv_mode].copy()
        conv.append_message(conv.roles[0], qs)
        conv.append_message(conv.roles[1], None)
        prompt = conv.get_prompt()

        input_ids = tokenizer_image_token(
            prompt, self.tokenizer, MM_TOKEN_INDEX, return_tensors='pt'
        ).unsqueeze(0).to(self.device)

        stop_str = conv.sep if conv.sep2 is None else conv.sep2
        stopping_criteria = KeywordsStoppingCriteria([stop_str], self.tokenizer, input_ids)

        with torch.inference_mode():
            output_ids = self.model.generate(
                input_ids,
                images=[images_tensor],
                temperature=temperature,
                do_sample=True,
                top_p=None,
                num_beams=1,
                no_repeat_ngram_size=3,
                max_new_tokens=1024,
                use_cache=True,
                stopping_criteria=[stopping_criteria],
            )

        outputs = self.tokenizer.decode(output_ids[0], skip_special_tokens=True).strip()
        # Strip the trailing stop string if generation ended on it.
        if outputs.endswith(stop_str):
            outputs = outputs[:-len(stop_str)]
        return outputs.strip()

    def __call__(self, request):
        # Unpack the msgpack payload and convert the raw bytes back to PIL images.
        packed_data = request['images'][0]
        unpacked_data = msgpack.unpackb(packed_data, raw=False)
        image_list = [Image.open(BytesIO(byte_data)) for byte_data in unpacked_data]

        prompt = request.get('prompt', [b''])[0].decode()
        temperature = float(request.get('temperature', [b'0.01'])[0].decode())

        # Fall back to a generic captioning prompt when none is supplied.
        if prompt == '':
            if len(image_list) == 1:
                prompt = "Please describe this image in detail."
            else:
                prompt = "Please describe this video in detail."

        with torch.no_grad():
            outputs = self.inference_frames(image_list, prompt, temperature)
        return {'output': [outputs]}


if __name__ == "__main__":
    video_dir = '/mnt/bn/yukunfeng-nasdrive/xiangchen/masp_data/20231110_ttp/video/v12044gd0000cl5c6rfog65i2eoqcqig'

    # Sort frames numerically by filename stem and load them as RGB images.
    frames = [(int(os.path.splitext(item)[0]), os.path.join(video_dir, item)) for item in os.listdir(video_dir)]
    frames = [item[1] for item in sorted(frames, key=lambda x: x[0])]
    out_frames = [Image.open(frame).convert('RGB') for frame in frames]

    # Encode the frames as JPEG bytes and pack them with msgpack,
    # mirroring the request format expected by EndpointHandler.__call__.
    byte_images = []
    for img in out_frames:
        byte_io = BytesIO()
        img.save(byte_io, format='JPEG')
        byte_images.append(byte_io.getvalue())
    packed_data = msgpack.packb(byte_images)

    request = {
        'images': [packed_data],
        'temperature': ['0.01'.encode()],
    }

    handler = EndpointHandler()
    print(handler(request))