import gradio as gr
import os
import subprocess
import sys
from threading import Thread
from queue import Queue
import time
import cv2
import datetime
import torch
import spaces
import numpy as np
import json
import hashlib
import PIL
from typing import Iterator

from llava import conversation as conversation_lib
from llava.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    DEFAULT_IM_START_TOKEN,
    DEFAULT_IM_END_TOKEN,
)
from llava.conversation import conv_templates, SeparatorStyle
from llava.model.builder import load_pretrained_model
from llava.utils import disable_torch_init
from llava.mm_utils import (
    tokenizer_image_token,
    get_model_name_from_path,
    KeywordsStoppingCriteria,
)
from serve_constants import html_header

import requests
from PIL import Image
from io import BytesIO
from transformers import TextIteratorStreamer

external_log_dir = "./logs"
LOGDIR = external_log_dir


def install_gradio_4_35_0():
    current_version = gr.__version__
    if current_version != "4.35.0":
        print(f"Current Gradio version: {current_version}")
        print("Installing Gradio 4.35.0...")
        subprocess.check_call(
            [sys.executable, "-m", "pip", "install", "gradio==4.35.0", "--force-reinstall"]
        )
        print("Gradio 4.35.0 installed successfully.")
    else:
        print("Gradio 4.35.0 is already installed.")


install_gradio_4_35_0()

print(f"Gradio version: {gr.__version__}")


def get_conv_log_filename():
    t = datetime.datetime.now()
    # Make sure the log directory exists before the file is opened for appending.
    os.makedirs(LOGDIR, exist_ok=True)
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    return name


class InferenceDemo(object):
    def __init__(
        self, args, model_path, tokenizer, model, image_processor, context_len
    ) -> None:
        disable_torch_init()

        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.context_len = context_len

        # Infer the conversation template from the model name.
        model_name = get_model_name_from_path(model_path)
        if "llama-2" in model_name.lower():
            conv_mode = "llava_llama_2"
        elif "v1" in model_name.lower():
            conv_mode = "llava_v1"
        elif "mpt" in model_name.lower():
            conv_mode = "mpt"
        elif "qwen" in model_name.lower():
            conv_mode = "qwen_1_5"
        elif "pangea" in model_name.lower():
            conv_mode = "qwen_1_5"
        else:
            conv_mode = "llava_v0"

        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto-inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
        else:
            args.conv_mode = conv_mode

        self.conv_mode = conv_mode
        self.conversation = conv_templates[args.conv_mode].copy()
        self.num_frames = args.num_frames


def process_stream(streamer: TextIteratorStreamer, history: list, q: Queue):
    """Consume the streamer output and put partial chat histories into a queue."""
    try:
        current_message = ""
        for new_text in streamer:
            current_message += new_text
            history[-1][1] = current_message
            q.put(history.copy())
            time.sleep(0.02)  # Small delay to avoid flooding the queue
    except Exception as e:
        print(f"Error in process_stream: {e}")
    finally:
        q.put(None)  # Signal that we're done


def stream_output(history: list, q: Queue) -> Iterator[list]:
    """Yield updated history as it comes through the queue."""
    while True:
        val = q.get()
        if val is None:
            break
        yield val
        q.task_done()


def is_valid_video_filename(name):
    video_extensions = ["avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg"]
    ext = name.split(".")[-1].lower()
    return ext in video_extensions


def is_valid_image_filename(name):
    image_extensions = [
        "jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp",
        "heic", "heif", "jfif", "svg", "eps", "raw",
    ]
    ext = name.split(".")[-1].lower()
    return ext in image_extensions
def sample_frames(video_file, num_frames):
    """Uniformly sample `num_frames` frames from a video as PIL images."""
    video = cv2.VideoCapture(video_file)
    total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
    # Guard against videos shorter than num_frames (an interval of 0 would divide by zero).
    interval = max(total_frames // num_frames, 1)
    frames = []
    for i in range(total_frames):
        ret, frame = video.read()
        if not ret:
            continue
        pil_img = Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if i % interval == 0:
            frames.append(pil_img)
    video.release()
    return frames


def load_image(image_file):
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file)
        if response.status_code == 200:
            image = Image.open(BytesIO(response.content)).convert("RGB")
        else:
            print("Failed to load the image")
            return None
    else:
        print("Load image from local file:", image_file)
        image = Image.open(image_file).convert("RGB")
    return image


def clear_history(history):
    global our_chatbot
    our_chatbot.conversation = conv_templates[our_chatbot.conv_mode].copy()
    return None


def add_message(history, message):
    global our_chatbot
    if len(history) == 0:
        our_chatbot = InferenceDemo(
            args, model_path, tokenizer, model, image_processor, context_len
        )

    for x in message["files"]:
        history.append(((x,), None))
    if message["text"] is not None:
        history.append((message["text"], None))
    return history, gr.MultimodalTextbox(value=None, interactive=False)


@spaces.GPU
def bot(history):
    global start_tstamp, finish_tstamp
    start_tstamp = time.time()

    text = history[-1][0]
    images_this_term = []
    num_new_images = 0
    for i, message in enumerate(history[:-1]):
        if isinstance(message[0], tuple):
            images_this_term.append(message[0][0])
            if is_valid_video_filename(message[0][0]):
                raise ValueError("Video is not supported")
            elif is_valid_image_filename(message[0][0]):
                num_new_images += 1
            else:
                raise ValueError("Invalid image file")
        else:
            num_new_images = 0

    assert len(images_this_term) > 0, "Must have an image"

    image_list = []
    for f in images_this_term:
        if is_valid_video_filename(f):
            image_list += sample_frames(f, our_chatbot.num_frames)
        elif is_valid_image_filename(f):
            image_list.append(load_image(f))
        else:
            raise ValueError("Invalid image file")

    image_tensor = [
        our_chatbot.image_processor.preprocess(f, return_tensors="pt")["pixel_values"][0]
        .half()
        .to(our_chatbot.model.device)
        for f in image_list
    ]

    # Hash the uploaded images and cache copies under the log directory.
    all_image_hash = []
    for image_path in images_this_term:
        with open(image_path, "rb") as image_file:
            image_data = image_file.read()
        image_hash = hashlib.md5(image_data).hexdigest()
        all_image_hash.append(image_hash)
        image = PIL.Image.open(image_path).convert("RGB")
        t = datetime.datetime.now()
        filename = os.path.join(
            LOGDIR,
            "serve_images",
            f"{t.year}-{t.month:02d}-{t.day:02d}",
            f"{image_hash}.jpg",
        )
        if not os.path.isfile(filename):
            os.makedirs(os.path.dirname(filename), exist_ok=True)
            image.save(filename)

    image_tensor = torch.stack(image_tensor)
    image_token = DEFAULT_IMAGE_TOKEN * num_new_images
    inp = image_token + "\n" + text

    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[0], inp)
    our_chatbot.conversation.append_message(our_chatbot.conversation.roles[1], None)
    prompt = our_chatbot.conversation.get_prompt()

    input_ids = (
        tokenizer_image_token(
            prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
        )
        .unsqueeze(0)
        .to(our_chatbot.model.device)
    )

    stop_str = (
        our_chatbot.conversation.sep
        if our_chatbot.conversation.sep_style != SeparatorStyle.TWO
        else our_chatbot.conversation.sep2
    )
    keywords = [stop_str]
    stopping_criteria = KeywordsStoppingCriteria(
        keywords, our_chatbot.tokenizer, input_ids
    )

    # Set up streaming
    q = Queue()
    streamer = TextIteratorStreamer(
        our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    # Consume the streamer in a separate thread; it pushes partial histories into q.
    thread = Thread(target=process_stream, args=(streamer, history, q))
    thread.start()

    # Run generation (blocks until the full response is produced).
    with torch.inference_mode():
        output_ids = our_chatbot.model.generate(
            input_ids,
            images=image_tensor,
            do_sample=True,
            temperature=0.2,
            max_new_tokens=1024,
            streamer=streamer,
            use_cache=True,
            stopping_criteria=[stopping_criteria],
        )

    finish_tstamp = time.time()

    # Log conversation
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            "tstamp": round(finish_tstamp, 4),
            "type": "chat",
            "model": "Pangea-7b",
            "start": round(start_tstamp, 4),
            "finish": round(finish_tstamp, 4),
            "state": history,
            "images": all_image_hash,
        }
        fout.write(json.dumps(data) + "\n")

    # Return a generator that will yield the queued history updates.
    return stream_output(history, q)


with gr.Blocks(
    css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}"
) as demo:
    gr.HTML(html_header)

    with gr.Column():
        with gr.Row():
            chatbot = gr.Chatbot(
                [], elem_id="Pangea", bubble_full_width=False, height=750
            )

        with gr.Row():
            upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
            downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
            flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
            regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)

    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image"],
        placeholder="Enter message or upload file...",
        show_label=False,
        submit_btn="🚀",
    )

    cur_dir = os.path.dirname(os.path.abspath(__file__))
    gr.Examples(
        examples_per_page=20,
        examples=[
            [
                {
                    "files": [f"{cur_dir}/examples/user_example_07.jpg"],
                    "text": "那要我问问你,你这个是什么🐱?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/user_example_05.jpg"],
                    "text": "この猫の目の大きさは、どのような理由で他の猫と比べて特に大きく見えますか?",
                },
            ],
            [
                {
                    "files": [f"{cur_dir}/examples/172197131626056_P7966202.png"],
                    "text": "Why is this image funny?",
                },
            ],
        ],
        inputs=[chat_input],
        label="Image",
    )

    chat_msg = chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input], queue=False
    ).then(
        bot, chatbot, chatbot, api_name="bot_response"
    ).then(
        lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input]
    )

    clear_btn.click(
        fn=clear_history,
        inputs=[chatbot],
        outputs=[chatbot],
        api_name="clear_all",
        queue=False,
    )

    regenerate_btn.click(
        fn=lambda history: history[:-1],
        inputs=[chatbot],
        outputs=[chatbot],
        queue=False,
    ).then(bot, chatbot, chatbot)

# Gradio 4.x renamed `concurrency_count` to `default_concurrency_limit`.
demo.queue(default_concurrency_limit=5)

if __name__ == "__main__":
    import argparse

    argparser = argparse.ArgumentParser()
    argparser.add_argument("--server_name", default="0.0.0.0", type=str)
    argparser.add_argument("--port", default="6123", type=str)
    argparser.add_argument("--model_path", default="neulab/Pangea-7B", type=str)
    # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    argparser.add_argument("--model-base", type=str, default=None)
    argparser.add_argument("--num-gpus", type=int, default=1)
    argparser.add_argument("--conv-mode", type=str, default=None)
    argparser.add_argument("--temperature", type=float, default=0.7)
    argparser.add_argument("--max-new-tokens", type=int, default=4096)
    argparser.add_argument("--num_frames", type=int, default=16)
    argparser.add_argument("--load-8bit", action="store_true")
    argparser.add_argument("--load-4bit", action="store_true")
    argparser.add_argument("--debug", action="store_true")
    args = argparser.parse_args()

    model_path = args.model_path
    filt_invalid = "cut"
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(
        args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit
    )
    model = model.to(torch.device("cuda"))
    our_chatbot = None
    # Pass the parsed server options through to Gradio instead of ignoring them.
    demo.launch(server_name=args.server_name, server_port=int(args.port))