Spaces:
Runtime error
Runtime error
# from .demo_modelpart import InferenceDemo | |
import gradio as gr | |
import os | |
from threading import Thread | |
# import time | |
import cv2 | |
import datetime | |
# import copy | |
import torch | |
import spaces | |
import numpy as np | |
from llava import conversation as conversation_lib | |
from llava.constants import DEFAULT_IMAGE_TOKEN | |
from llava.constants import ( | |
IMAGE_TOKEN_INDEX, | |
DEFAULT_IMAGE_TOKEN, | |
DEFAULT_IM_START_TOKEN, | |
DEFAULT_IM_END_TOKEN, | |
) | |
from llava.conversation import conv_templates, SeparatorStyle | |
from llava.model.builder import load_pretrained_model | |
from llava.utils import disable_torch_init | |
from llava.mm_utils import ( | |
tokenizer_image_token, | |
get_model_name_from_path, | |
KeywordsStoppingCriteria, | |
) | |
from serve_constants import html_header, bibtext, learn_more_markdown, tos_markdown | |
from decord import VideoReader, cpu | |
import requests | |
from PIL import Image | |
import io | |
from io import BytesIO | |
from transformers import TextStreamer, TextIteratorStreamer | |
import hashlib | |
import PIL | |
import base64 | |
import json | |
import datetime | |
import gradio as gr | |
import gradio_client | |
import subprocess | |
import sys | |
from huggingface_hub import HfApi | |
from huggingface_hub import login | |
from huggingface_hub import revision_exists | |
# Authenticate with the Hugging Face Hub using the HF_TOKEN secret (KeyError if it is
# unset). Write permission is requested because the vote/conversation logs and the
# archived media are pushed to a dataset repo through `api` below.
login(token=os.environ["HF_TOKEN"],
      write_permission=True)
api = HfApi()
# Dataset repository that receives the uploaded logs, votes, and media (KeyError if unset).
repo_name = os.environ["LOG_REPO"]
# Local staging directories. Their "./logs/" and "./votes/" prefixes are stripped to
# form the in-repo path when files are uploaded.
external_log_dir = "./logs"
LOGDIR = external_log_dir
VOTEDIR = "./votes"
def install_gradio_4_35_0():
    """Force-reinstall Gradio 4.35.0 with pip unless the imported gradio already is 4.35.0.

    Progress is reported with print(); a failing pip call raises
    subprocess.CalledProcessError.
    """
    installed = gr.__version__
    if installed == "4.35.0":
        print("Gradio 4.35.0 is already installed.")
        return
    print(f"Current Gradio version: {installed}")
    print("Installing Gradio 4.35.0...")
    pip_cmd = [sys.executable, "-m", "pip", "install", "gradio==4.35.0", "--force-reinstall"]
    subprocess.check_call(pip_cmd)
    print("Gradio 4.35.0 installed successfully.")
# Call the function to install Gradio 4.35.0 if needed
install_gradio_4_35_0()
# NOTE(review): gradio and gradio_client were already imported at the top of the file,
# so these statements return the cached modules from sys.modules instead of loading a
# version pip may have just installed. The versions printed below are therefore the
# ones loaded at process start; a reinstall only takes effect after a restart.
import gradio as gr
import gradio_client
print(f"Gradio version: {gr.__version__}")
print(f"Gradio-client version: {gradio_client.__version__}")
def get_conv_log_filename():
    """Return today's conversation-log path under LOGDIR, creating LOGDIR if needed.

    The file name has the form ``YYYY-MM-DD-user_conv.json``.
    """
    t = datetime.datetime.now()
    name = os.path.join(LOGDIR, f"{t.year}-{t.month:02d}-{t.day:02d}-user_conv.json")
    # Callers open this path in append mode. For a text-only chat nothing else has
    # created LOGDIR yet, so ensure it exists (as get_conv_vote_filename does for VOTEDIR).
    log_dir = os.path.dirname(name)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    return name
def get_conv_vote_filename():
    """Return today's vote-log path under VOTEDIR (``YYYY-MM-DD-user_vote.json``).

    When the file does not exist yet, its parent directory is created so the
    caller can open it for appending.
    """
    now = datetime.datetime.now()
    vote_path = os.path.join(VOTEDIR, f"{now.year}-{now.month:02d}-{now.day:02d}-user_vote.json")
    if os.path.isfile(vote_path):
        return vote_path
    os.makedirs(os.path.dirname(vote_path), exist_ok=True)
    return vote_path
def vote_last_response(state, vote_type, model_selector):
    """Append one vote record to today's vote log and upload the log to the Hub.

    Args:
        state: the chatbot history being voted on; serialized with json.dumps.
        vote_type: label stored under "type" (callers pass "upvote"/"downvote").
        model_selector: model name stored under "model".
    """
    # Resolve the path once. Calling get_conv_vote_filename() separately for the
    # write and the upload could name two different files across a date rollover.
    vote_file = get_conv_vote_filename()
    data = {
        "type": vote_type,
        "model": model_selector,
        "state": state,
    }
    with open(vote_file, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    # Upload after the file is closed so the new record has been flushed.
    api.upload_file(
        path_or_fileobj=vote_file,
        path_in_repo=vote_file.replace("./votes/", ""),
        repo_id=repo_name,
        repo_type="dataset")
def upvote_last_response(state):
    """Record an upvote for the conversation in *state*, thank the user, return *state* unchanged."""
    vote_last_response(state, vote_type="upvote", model_selector="MAmmoTH-VL-8b")
    gr.Info("Thank you for your voting!")
    return state
def downvote_last_response(state):
    """Record a downvote for the conversation in *state*, thank the user, return *state* unchanged."""
    vote_last_response(state, vote_type="downvote", model_selector="MAmmoTH-VL-8b")
    gr.Info("Thank you for your voting!")
    return state
class InferenceDemo(object):
    """Bundle the loaded model components with a running conversation for the demo UI."""

    # Ordered (substring, conversation-template key) rules used to infer the template
    # from the model name. The first substring found in the lower-cased name wins.
    _CONV_MODE_RULES = (
        ("llama-2", "llava_llama_2"),
        ("v1", "llava_v1"),
        ("mpt", "mpt"),
        ("qwen", "qwen_1_5"),
        ("pangea", "qwen_1_5"),
        ("mammoth-vl", "qwen_2_5"),
    )
    # Template used when none of the rules above match.
    _DEFAULT_CONV_MODE = "llava_v0"

    def __init__(
        self, args, model_path, tokenizer, model, image_processor, context_len
    ) -> None:
        """Store the model components and start a fresh conversation.

        Args:
            args: parsed CLI namespace. ``args.conv_mode`` (may be None) and
                ``args.num_frames`` are read. ``args.conv_mode`` is overwritten with
                the inferred template unless the user chose a different one.
            model_path: model path or repo id; its model name selects the default
                conversation template.
            tokenizer, model, image_processor, context_len: the model components
                (in this file, the values returned by load_pretrained_model).
        """
        disable_torch_init()
        self.tokenizer = tokenizer
        self.model = model
        self.image_processor = image_processor
        self.context_len = context_len

        # Derive the name from our own argument rather than the module-level
        # `model_name` global, which is only bound when the script runs as __main__.
        model_name = get_model_name_from_path(model_path)
        conv_mode = self._infer_conv_mode(model_name)
        if args.conv_mode is not None and conv_mode != args.conv_mode:
            print(
                "[WARNING] the auto inferred conversation mode is {}, while `--conv-mode` is {}, using {}".format(
                    conv_mode, args.conv_mode, args.conv_mode
                )
            )
        else:
            args.conv_mode = conv_mode
        # Remember the template actually in use. clear_history() rebuilds the
        # conversation from self.conv_mode, so storing the inferred mode here would
        # drop a user-supplied --conv-mode after the first "Clear history".
        self.conv_mode = args.conv_mode
        self.conversation = conv_templates[self.conv_mode].copy()
        self.num_frames = args.num_frames

    @staticmethod
    def _infer_conv_mode(model_name):
        """Return the conversation-template key implied by *model_name* (case-insensitive)."""
        lowered = model_name.lower()
        for needle, mode in InferenceDemo._CONV_MODE_RULES:
            if needle in lowered:
                return mode
        return InferenceDemo._DEFAULT_CONV_MODE
class ChatSessionManager:
    """Lazily create and cache a single InferenceDemo instance."""

    def __init__(self):
        # Cached chatbot; None until first requested (or after reset_chatbot()).
        self.chatbot_instance = None

    def initialize_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Build a new InferenceDemo from the given components and cache it."""
        fresh_bot = InferenceDemo(args, model_path, tokenizer, model, image_processor, context_len)
        self.chatbot_instance = fresh_bot
        print(f"Initialized Chatbot instance with ID: {id(fresh_bot)}")

    def reset_chatbot(self):
        """Forget the cached chatbot so the next get_chatbot() call builds a new one."""
        self.chatbot_instance = None

    def get_chatbot(self, args, model_path, tokenizer, model, image_processor, context_len):
        """Return the cached chatbot, building it from the given components on first use."""
        if self.chatbot_instance is not None:
            return self.chatbot_instance
        self.initialize_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
        return self.chatbot_instance
def is_valid_video_filename(name):
    """Return True if the text after the last "." in *name* is a known video extension.

    The comparison is case-insensitive. A name without any "." is compared as a whole.
    """
    suffix = name.rpartition(".")[2].lower()
    return suffix in ("avi", "mp4", "mov", "mkv", "flv", "wmv", "mjpeg")
def is_valid_image_filename(name):
    """Return True if the text after the last "." in *name* is a known image extension.

    The comparison is case-insensitive. A name without any "." is compared as a whole.
    """
    suffix = name.rpartition(".")[2].lower()
    return suffix in (
        "jpg", "jpeg", "png", "bmp", "gif", "tiff", "webp",
        "heic", "heif", "jfif", "svg", "eps", "raw",
    )
def sample_frames_old(video_file, num_frames):
    """Legacy OpenCV sampler: return every ``total_frames // num_frames``-th frame as an RGB PIL image.

    The stride is clamped to at least 1. The result is not capped at *num_frames*
    (e.g. 100 frames with num_frames=32 gives a stride of 3 and 34 frames).
    Returns an empty list when *num_frames* is not positive.
    """
    if num_frames <= 0:
        return []
    video = cv2.VideoCapture(video_file)
    frames = []
    try:
        total_frames = int(video.get(cv2.CAP_PROP_FRAME_COUNT))
        # Clamp so clips shorter than num_frames don't cause a modulo by zero.
        interval = max(total_frames // num_frames, 1)
        for i in range(total_frames):
            ret, frame = video.read()
            # Check the read status before touching the frame: a failed read
            # yields no image, and converting it would raise.
            if not ret:
                continue
            if i % interval == 0:
                frames.append(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
    finally:
        video.release()
    return frames
def _sample_frame_indices(total_frames, frame_count):
    """Return at most *frame_count* frame indices taken at a fixed stride from frame 0.

    The stride is ``total_frames // frame_count`` (at least 1), so a clip shorter
    than *frame_count* yields each of its frames exactly once.
    """
    if total_frames <= 0 or frame_count <= 0:
        return []
    stride = max(total_frames // frame_count, 1)
    return list(range(0, total_frames, stride))[:frame_count]


def _encode_frame_base64_jpeg(frame):
    """Encode one frame array as a base64 JPEG string."""
    buffered = io.BytesIO()
    Image.fromarray(frame).save(buffered, format="JPEG")
    return base64.b64encode(buffered.getvalue()).decode("utf-8")


def sample_frames(video_path, frame_count=32):
    """Sample up to *frame_count* frames of a video as base64-encoded JPEG strings.

    Frames are read with decord at a stride of ``len(video) // frame_count``
    (minimum 1), starting at frame 0. An empty video gives an empty list.
    """
    vr = VideoReader(video_path, ctx=cpu(0))
    # Short videos return each frame once. The previous "top-up" pass re-appended
    # frames from the start, duplicating them out of temporal order.
    return [
        _encode_frame_base64_jpeg(vr[index].asnumpy())
        for index in _sample_frame_indices(len(vr), frame_count)
    ]
def load_image(image_file, timeout=30):
    """Load an image from an http(s) URL or a local path and return it as an RGB PIL image.

    Args:
        image_file: ``http://``/``https://`` URL or local file path.
        timeout: seconds passed to ``requests.get`` for URL downloads.

    Raises:
        OSError: if the URL request does not answer with HTTP 200.
        requests.RequestException: on network errors or timeouts.
    """
    if image_file.startswith(("http://", "https://")):
        response = requests.get(image_file, timeout=timeout)
        if response.status_code != 200:
            # Fail loudly. Otherwise there is no image to return and the caller
            # would hit an unbound local instead of a meaningful error.
            raise OSError(
                f"failed to load the image: {image_file} (HTTP {response.status_code})"
            )
        return Image.open(BytesIO(response.content)).convert("RGB")
    print("Load image from local file")
    print(image_file)
    return Image.open(image_file).convert("RGB")
def clear_response(history):
    """Remove the most recent user turn from *history* and return it.

    Walks backwards from the end of *history* to the last entry whose user side
    (``entry[0]``) is not None, then drops that entry and everything after it.

    Returns:
        (remaining_history, question): the truncated history and the user content
        of the removed entry; ``(history, None)`` when *history* is empty.
    """
    if not history:
        return history, None
    # Iterate up to len(history) inclusive so the first entry is examined too and
    # index_conv is always bound, even for a single-entry history.
    index_conv = 1
    for index_conv in range(1, len(history) + 1):
        if history[-index_conv][0] is not None:
            break
    question = history[-index_conv][0]
    return history[:-index_conv], question
# Module-level session manager used by clear_history, add_message and bot.
# NOTE(review): as a single module-level object it appears to be shared by every
# browser session of the demo; confirm that this is intended.
chat_manager = ChatSessionManager()
def clear_history(history):
    """Reset the chatbot's conversation to a fresh copy of its template.

    *history* is ignored. Returns None, which is written back to the chatbot component.
    """
    session = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
    fresh_conversation = conv_templates[session.conv_mode].copy()
    session.conversation = fresh_conversation
    return None
def add_message(history, message):
    """Append the user's uploaded files and text to the chat history.

    Files become ``((path,), None)`` entries and text becomes ``(text, None)``.
    The module-level ``chat_image_num`` counts single-file uploads in the current
    conversation. Once it exceeds 1, the history is cleared and the chatbot is
    reset, so the new upload starts a fresh conversation.

    Returns:
        (history, gr.MultimodalTextbox): the updated history and a cleared,
        non-interactive input box.
    """
    global chat_image_num
    # Fetch (or lazily create) the session up front so `our_chatbot` is bound on
    # every path. It used to be assigned only for empty/reset histories, so a
    # follow-up turn raised UnboundLocalError at the debug print below.
    our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
    if not history:
        history = []
        chat_image_num = 0
    if len(message["files"]) <= 1:
        for x in message["files"]:
            history.append(((x,), None))
            chat_image_num += 1
        if chat_image_num > 1:
            # A further single-file upload starts a brand-new conversation.
            history = []
            chat_manager.reset_chatbot()
            our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
            chat_image_num = 0
            for x in message["files"]:
                history.append(((x,), None))
                chat_image_num += 1
        if message["text"] is not None:
            history.append((message["text"], None))
        print(f"### Chatbot instance ID: {id(our_chatbot)}")
        return history, gr.MultimodalTextbox(value=None, interactive=False)
    # Multi-file uploads are appended as-is and do not touch chat_image_num.
    for x in message["files"]:
        history.append(((x,), None))
    if message["text"] is not None:
        history.append((message["text"], None))
    return history, gr.MultimodalTextbox(value=None, interactive=False)
def _gather_media(history, num_frames):
    """Collect the uploaded files referenced in ``history[:-1]``.

    Returns:
        (media_paths, num_new_images, video_frames):
        - media_paths: every uploaded file path, in order.
        - num_new_images: images plus video frames uploaded since the last text
          entry, i.e. those that need image tokens in this turn's prompt.
        - video_frames: maps each video path to its sampled frames as RGB PIL
          images. Videos are decoded only once.

    Raises:
        ValueError: if an uploaded file is neither a supported image nor video.
    """
    media_paths = []
    video_frames = {}
    num_new_images = 0
    for message in history[:-1]:
        if not isinstance(message[0], tuple):
            # A text entry: earlier uploads got their image tokens in a previous turn.
            num_new_images = 0
            continue
        path = message[0][0]
        media_paths.append(path)
        if is_valid_video_filename(path):
            if path not in video_frames:
                # sample_frames returns base64 JPEG strings. Decode them to PIL
                # images, which is what the image processor expects.
                video_frames[path] = [
                    Image.open(BytesIO(base64.b64decode(encoded))).convert("RGB")
                    for encoded in sample_frames(path, num_frames)
                ]
            num_new_images += len(video_frames[path])
        elif is_valid_image_filename(path):
            print("#### Load image from local file", path)
            num_new_images += 1
        else:
            raise ValueError("Invalid file format")
    return media_paths, num_new_images, video_frames


def _archive_media(media_paths):
    """Copy each uploaded file into ``LOGDIR/serve_files/<date>/``, named by its MD5.

    Images are re-saved as ``<md5>.jpg`` and videos copied as ``<md5>.mp4``.
    Existing archive files are left untouched.

    Returns:
        (hashes, saved_paths): the MD5 hex digest of every file, and the archive
        path of every file recognised as an image or video.
    """
    hashes = []
    saved_paths = []
    for file_path in media_paths:
        with open(file_path, "rb") as src:
            file_data = src.read()
        file_hash = hashlib.md5(file_data).hexdigest()
        hashes.append(file_hash)
        t = datetime.datetime.now()
        output_dir = os.path.join(
            LOGDIR,
            "serve_files",
            f"{t.year}-{t.month:02d}-{t.day:02d}"
        )
        os.makedirs(output_dir, exist_ok=True)
        if is_valid_image_filename(file_path):
            filename = os.path.join(output_dir, f"{file_hash}.jpg")
            saved_paths.append(filename)
            if not os.path.isfile(filename):
                print("Image saved to", filename)
                Image.open(file_path).convert("RGB").save(filename)
        elif is_valid_video_filename(file_path):
            filename = os.path.join(output_dir, f"{file_hash}.mp4")
            saved_paths.append(filename)
            if not os.path.isfile(filename):
                print("Video saved to", filename)
                # The bytes were already read for hashing; write them out directly.
                with open(filename, "wb") as dst:
                    dst.write(file_data)
    return hashes, saved_paths


def _log_conversation(history, image_hashes, image_paths):
    """Append this turn to today's conversation log, then upload the archived media and the log."""
    # Resolve the log path once so the write and the upload use the same file.
    log_file = get_conv_log_filename()
    log_dir = os.path.dirname(log_file)
    if log_dir:
        # Text-only chats never create LOGDIR through _archive_media.
        os.makedirs(log_dir, exist_ok=True)
    data = {
        "type": "chat",
        "model": "MAmmoTH-VL-8b",
        "state": history,
        "images": image_hashes,
        "images_path": image_paths
    }
    print("#### conv log", data)
    with open(log_file, "a") as fout:
        fout.write(json.dumps(data) + "\n")
    for upload_img in image_paths:
        api.upload_file(
            path_or_fileobj=upload_img,
            path_in_repo=upload_img.replace("./logs/", ""),
            repo_id=repo_name,
            repo_type="dataset",
        )
    api.upload_file(
        path_or_fileobj=log_file,
        path_in_repo=log_file.replace("./logs/", ""),
        repo_id=repo_name,
        repo_type="dataset")


def bot(history, temperature, top_p, max_output_tokens):
    """Stream the model's reply to the latest user message (Gradio generator callback).

    Every uploaded image/video in the history is fed to the model. Image tokens
    are added to the prompt only for uploads made since the last text message.
    Each partial reply is yielded as the updated history. Afterwards the turn is
    logged locally and uploaded to the Hub dataset ``repo_name``.

    Args:
        history: chatbot history; ``history[-1][0]`` is the user's new text.
        temperature, top_p, max_output_tokens: sampling settings from the UI.

    Raises:
        ValueError: if an uploaded file is neither a supported image nor video.
    """
    our_chatbot = chat_manager.get_chatbot(args, model_path, tokenizer, model, image_processor, context_len)
    print(f"### Chatbot instance ID: {id(our_chatbot)}")
    text = history[-1][0]

    media_paths, num_new_images, video_frames = _gather_media(history, our_chatbot.num_frames)

    image_list = []
    for path in media_paths:
        if is_valid_video_filename(path):
            image_list.extend(video_frames[path])
        else:
            # _gather_media already rejected anything that is neither image nor video.
            image_list.append(load_image(path))

    all_image_hash, all_image_path = _archive_media(media_paths)

    device = our_chatbot.model.device
    pixel_values = [
        our_chatbot.image_processor.preprocess(img, return_tensors="pt")["pixel_values"][0]
        .half()
        .to(device)
        for img in image_list
    ]
    # torch.stack rejects an empty list, and a text-only chat has no images at all.
    image_tensor = torch.stack(pixel_values) if pixel_values else None

    image_token = DEFAULT_IMAGE_TOKEN * num_new_images
    inp = image_token + "\n" + text
    conversation = our_chatbot.conversation
    conversation.append_message(conversation.roles[0], inp)
    conversation.append_message(conversation.roles[1], None)
    prompt = conversation.get_prompt()
    input_ids = tokenizer_image_token(
        prompt, our_chatbot.tokenizer, IMAGE_TOKEN_INDEX, return_tensors="pt"
    ).unsqueeze(0).to(device)
    stop_str = (
        conversation.sep
        if conversation.sep_style != SeparatorStyle.TWO
        else conversation.sep2
    )
    stopping_criteria = KeywordsStoppingCriteria(
        [stop_str], our_chatbot.tokenizer, input_ids
    )
    streamer = TextIteratorStreamer(
        our_chatbot.tokenizer, skip_prompt=True, skip_special_tokens=True
    )
    print(device)
    print(input_ids.device)
    print(image_tensor.device if image_tensor is not None else None)
    generate_kwargs = dict(
        inputs=input_ids,
        streamer=streamer,
        images=image_tensor,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        max_new_tokens=max_output_tokens,
        use_cache=False,
        stopping_criteria=[stopping_criteria],
    )
    # Generation runs in a worker thread and the streamer yields text as it is produced.
    generation = Thread(target=our_chatbot.model.generate, kwargs=generate_kwargs)
    generation.start()
    outputs = []
    for stream_token in streamer:
        outputs.append(stream_token)
        history[-1] = [text, "".join(outputs)]
        yield history
    generation.join()
    our_chatbot.conversation.messages[-1][-1] = "".join(outputs)

    _log_conversation(history, all_image_hash, all_image_path)
# NOTE(review): this Textbox is created before (outside) the `gr.Blocks` context that
# follows and is not referenced again in the visible code; it looks like a leftover
# from an earlier text-only UI and could likely be removed.
txt = gr.Textbox(
    scale=4,
    show_label=False,
    placeholder="Enter text and press enter.",
    container=False,
)
with gr.Blocks(
    css=".message-wrap.svelte-1lcyrx4>div.svelte-1lcyrx4 img {min-width: 40px}",
) as demo:
    # Directory of this script; the bundled example media live in its examples/ folder.
    cur_dir = os.path.dirname(os.path.abspath(__file__))
    # gr.Markdown(title_markdown)
    gr.HTML(html_header)
    with gr.Column():
        # Sampling controls passed to bot() on every submit.
        with gr.Accordion("Parameters", open=False) as parameter_row:
            temperature = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=0.7,
                step=0.1,
                interactive=True,
                label="Temperature",
            )
            top_p = gr.Slider(
                minimum=0.0,
                maximum=1.0,
                value=1,
                step=0.1,
                interactive=True,
                label="Top P",
            )
            max_output_tokens = gr.Slider(
                minimum=0,
                maximum=8192,
                value=4096,
                step=256,
                interactive=True,
                label="Max output tokens",
            )
        with gr.Row():
            chatbot = gr.Chatbot([], elem_id="MAmmoTH-VL-8B", bubble_full_width=False, height=750)
        with gr.Row():
            upvote_btn = gr.Button(value="👍 Upvote", interactive=True)
            downvote_btn = gr.Button(value="👎 Downvote", interactive=True)
            # NOTE(review): flag_btn and regenerate_btn have no event handlers attached below.
            flag_btn = gr.Button(value="⚠️ Flag", interactive=True)
            # stop_btn = gr.Button(value="⏹️ Stop Generation", interactive=True)
            regenerate_btn = gr.Button(value="🔄 Regenerate", interactive=True)
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=True)
    chat_input = gr.MultimodalTextbox(
        interactive=True,
        file_types=["image", "video"],
        placeholder="Enter message or upload file...",
        show_label=False,
        submit_btn="🚀"
    )
    print(cur_dir)
    gr.Examples(
        examples_per_page=20,
        examples=[
            [
                {
                    "files": [
                        f"{cur_dir}/examples/172197131626056_P7966202.png",
                    ],
                    "text": "Why this image funny?",
                }
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/realcase_doc.png",
                    ],
                    "text": "Read text in the image",
                }
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/realcase_weather.jpg",
                    ],
                    "text": "List the weather for Monday to Friday",
                }
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/realcase_knowledge.jpg",
                    ],
                    "text": "Answer the following question based on the provided image: What country do these planes belong to?",
                }
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/realcase_math.jpg",
                    ],
                    "text": "Find the measure of angle 3.",
                }
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/realcase_interact.jpg",
                    ],
                    "text": "Please perfectly describe this cartoon illustration in as much detail as possible",
                }
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/realcase_perfer.jpg",
                    ],
                    "text": "This is an image of a room. It could either be a real image captured in the room or a rendered image from a 3D scene reconstruction technique that is trained using real images of the room. A rendered image usually contains some visible artifacts (eg. blurred regions due to under-reconstructed areas) that do not faithfully represent the actual scene. You need to decide if its a real image or a rendered image by giving each image a photorealism score between 1 and 5.",
                }
            ],
            [
                {
                    "files": [
                        f"{cur_dir}/examples/realcase_multi1.png",
                        f"{cur_dir}/examples/realcase_multi2.png",
                        f"{cur_dir}/examples/realcase_multi3.png",
                        f"{cur_dir}/examples/realcase_multi4.png",
                        f"{cur_dir}/examples/realcase_multi5.png",
                    ],
                    "text": "Based on the five species in the images, draw a food chain. Explain the role of each species in the food chain.",
                }
            ],
        ],
        inputs=[chat_input],
        label="Real World Image Cases",
    )
    gr.Examples(
        examples=[
            [
                {
                    "files": [
                        f"{cur_dir}/examples/realcase_video.mp4",
                    ],
                    "text": "Please describe the video in detail.",
                },
            ]
        ],
        inputs=[chat_input],
        label="Real World Video Case"
    )
    gr.Markdown(tos_markdown)
    gr.Markdown(learn_more_markdown)
    gr.Markdown(bibtext)
    # Submit pipeline: record the message (and lock the input box), stream the
    # model's reply, then re-enable the input box.
    chat_input.submit(
        add_message, [chatbot, chat_input], [chatbot, chat_input]
    ).then(
        bot, [chatbot, temperature, top_p, max_output_tokens], chatbot, api_name="bot_response"
    ).then(lambda: gr.MultimodalTextbox(interactive=True), None, [chat_input])
    # chatbot.like(print_like_dislike, None, None)
    clear_btn.click(
        fn=clear_history, inputs=[chatbot], outputs=[chatbot], api_name="clear_all"
    )
    upvote_btn.click(
        fn=upvote_last_response, inputs=chatbot, outputs=chatbot, api_name="upvote_last_response"
    )
    # Each endpoint needs its own api_name. This one previously duplicated
    # "upvote_last_response", which collides with the upvote endpoint.
    downvote_btn.click(
        fn=downvote_last_response, inputs=chatbot, outputs=chatbot, api_name="downvote_last_response"
    )
demo.queue()
if __name__ == "__main__":
    import argparse
    argparser = argparse.ArgumentParser()
    # NOTE(review): --server_name and --port are parsed but never passed to
    # demo.launch() below, so they currently have no effect.
    argparser.add_argument("--server_name", default="0.0.0.0", type=str)
    argparser.add_argument("--port", default="6123", type=str)
    argparser.add_argument(
        "--model_path", default="MMSFT/MAmmoTH-VL-8B", type=str
    )
    # argparser.add_argument("--model-path", type=str, default="facebook/opt-350m")
    argparser.add_argument("--model-base", type=str, default=None)
    # NOTE(review): --num-gpus and --debug are not read in the visible code.
    argparser.add_argument("--num-gpus", type=int, default=1)
    # When unset, InferenceDemo infers the conversation template from the model name.
    argparser.add_argument("--conv-mode", type=str, default=None)
    # NOTE(review): --temperature and --max-new-tokens are not read in the visible
    # code; bot() receives these values from the UI sliders instead.
    argparser.add_argument("--temperature", type=float, default=0.7)
    argparser.add_argument("--max-new-tokens", type=int, default=4096)
    # Frames sampled per uploaded video (stored as InferenceDemo.num_frames).
    argparser.add_argument("--num_frames", type=int, default=32)
    argparser.add_argument("--load-8bit", action="store_true")
    argparser.add_argument("--load-4bit", action="store_true")
    argparser.add_argument("--debug", action="store_true")
    args = argparser.parse_args()
    # args, model_path, tokenizer, model, image_processor, context_len and
    # chat_image_num become module-level globals that the Gradio callbacks above
    # (clear_history, add_message, bot) read directly.
    model_path = args.model_path
    filt_invalid = "cut"  # NOTE(review): not used in the visible code
    model_name = get_model_name_from_path(args.model_path)
    tokenizer, model, image_processor, context_len = load_pretrained_model(args.model_path, args.model_base, model_name, args.load_8bit, args.load_4bit)
    model=model.to(torch.device('cuda'))
    chat_image_num = 0
    demo.launch()