|
import subprocess

# Install FlashAttention at startup, before transformers is imported. Skipping
# the CUDA build assumes a prebuilt wheel is available for this environment
# (the usual case on Spaces GPU images).
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)

import time
from threading import Thread

import gradio as gr
import numpy as np
import spaces
import torch
from decord import VideoReader, cpu
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor, TextIteratorStreamer

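# Videos are handled by sampling a fixed number of frames with decord and
# feeding them to the model as a list of images.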
def load_video(video_path, frames=32):
    """
    Load a video and extract a specified number of frames as PIL.Image objects.

    Parameters:
    - video_path (str): Path to the video file.
    - frames (int): Number of frames to extract.

    Returns:
    - List[PIL.Image]: A list of PIL.Image objects for the extracted frames.
    """
    vr = VideoReader(video_path, ctx=cpu())
    total_frames = len(vr)

    # Pick `frames` indices evenly spaced across the whole clip.
    frame_indices = np.linspace(0, total_frames - 1, frames, dtype=int)

    images = []
    for idx in frame_indices:
        frame = vr[idx]
        images.append(Image.fromarray(frame.asnumpy()))

    return images

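# Aria-Chat ships custom modeling code on the Hub, so both the model and the
# processor are loaded with trust_remote_code=True.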
model_id_or_path = "teowu/Aria-Chat-Preview"

model = AutoModelForCausalLM.from_pretrained(
    model_id_or_path,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    trust_remote_code=True,
)

processor = AutoProcessor.from_pretrained(model_id_or_path, trust_remote_code=True)

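# On ZeroGPU Spaces, @spaces.GPU attaches a GPU to the process only while the
# decorated function runs.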
@spaces.GPU
def model_inference(
    input_dict, history, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    text = input_dict["text"]
    print(input_dict["files"])
    if len(input_dict["files"]) > 1:
        images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
    elif len(input_dict["files"]) == 1:
        if input_dict["files"][0].endswith(".mp4") or input_dict["files"][0].endswith(".avi"):
            images = load_video(input_dict["files"][0])
        else:
            images = [Image.open(input_dict["files"][0]).convert("RGB")]
    else:
        images = []

    if text == "" and not images:
        raise gr.Error("Please input a query and optionally image(s).")

    if text == "" and images:
        text = "Please provide a detailed description."

    # One image placeholder per frame/image, followed by the text query.
    resulting_messages = [
        {
            "role": "user",
            "content": [{"type": "image", "text": None} for _ in range(len(images))]
            + [{"type": "text", "text": "\n" + text}],
        }
    ]
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    # Pass None instead of an empty list for text-only turns.
    inputs = processor(text=prompt, images=images if images else None, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}
    # The model runs in bfloat16, so pixel values must match its dtype.
    if "pixel_values" in inputs:
        inputs["pixel_values"] = inputs["pixel_values"].to(model.dtype)
    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }

    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    generation_args.update(inputs)

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    generation_args["streamer"] = streamer

    # Run generation on a background thread so tokens can be streamed to the UI.
    thread = Thread(target=model.generate, kwargs=generation_args)
    thread.start()

    yield "..."
    buffer = ""

    for new_text in streamer:
        buffer += new_text
        time.sleep(0.01)
        yield buffer

# Each example supplies the multimodal message plus (decoding strategy,
# temperature, max new tokens, repetition penalty, top p), matching the order
# of additional_inputs below.
examples = [
    [{"text": "What art era do these artpieces belong to?", "files": ["example_images/rococo.jpg", "example_images/rococo_1.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "I'm planning a visit to this temple, give me travel tips.", "files": ["example_images/examples_wat_arun.jpg"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "What is the due date and the invoice date?", "files": ["example_images/examples_invoice.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "What is this UI about?", "files": ["example_images/s2w_example.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
    [{"text": "Where do the severe droughts happen according to this diagram?", "files": ["example_images/examples_weather_events.png"]}, "Greedy", 0.4, 512, 1.2, 0.8],
]

demo = gr.ChatInterface(
    fn=model_inference,
    title="Aria-Chat: Improved Real-world Abilities for Open-source LMMs on Images and Videos",
    description="Play with [rhymes-ai/Aria-Chat-Preview](https://huggingface.co/rhymes-ai/Aria-Chat-Preview) in this demo. To get started, upload an image (or a video) and text, or try one of the examples. This checkpoint works best with single-turn conversations, so clear the conversation after a single turn.",
    examples=examples,
    textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image", "video"], file_count="multiple"),
    stop_btn="Stop Generation",
    multimodal=True,
    additional_inputs=[
        gr.Radio(
            ["Top P Sampling", "Greedy"],
            value="Greedy",
            label="Decoding strategy",
            info="Greedy picks the most likely token at each step; Top P samples from the smallest set of tokens whose cumulative probability exceeds Top P.",
        ),
        gr.Slider(
            minimum=0.0,
            maximum=5.0,
            value=0.4,
            step=0.1,
            interactive=True,
            label="Sampling temperature",
            info="Higher values will produce more diverse outputs.",
        ),
        gr.Slider(
            minimum=8,
            maximum=1024,
            value=512,
            step=1,
            interactive=True,
            label="Maximum number of new tokens to generate",
        ),
        # Restores the repetition-penalty control that model_inference and the
        # examples above expect.
        gr.Slider(
            minimum=0.01,
            maximum=5.0,
            value=1.2,
            step=0.01,
            interactive=True,
            label="Repetition penalty",
            info="1.0 is equivalent to no penalty.",
        ),
        gr.Slider(
            minimum=0.01,
            maximum=0.99,
            value=0.8,
            step=0.01,
            interactive=True,
            label="Top P",
            info="Higher values are equivalent to sampling more low-probability tokens.",
        ),
    ],
    cache_examples=False,
)

# debug=True blocks the main thread and prints errors to the console.
demo.launch(debug=True)