Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,161 Bytes
0b17160 aa179a2 36f093d 0b17160 b7864ca 0b17160 b7864ca 0b17160 aa179a2 588861c aa179a2 0b17160 e4536f3 0b17160 e4536f3 67e4fcc e4536f3 0b17160 e4536f3 b414d97 0b17160 3aabbc8 0b17160 aa179a2 0b17160 aa179a2 0b17160 36f093d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 |
from __future__ import annotations
import spaces
import gradio as gr
from threading import Thread
from transformers import TextIteratorStreamer
import hashlib
import os
from transformers import AutoModel, AutoProcessor
import torch
MODEL_ID = "visheratin/MC-LLaVA-3b"

# Decide the target device/dtype BEFORE loading the model, so the model is
# placed consistently with DEVICE/DTYPE (which the rest of the file uses) and
# the app can still start on a host without CUDA.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DTYPE = torch.float16
else:
    DEVICE = "cpu"
    DTYPE = torch.float32

model = AutoModel.from_pretrained(
    MODEL_ID, torch_dtype=DTYPE, trust_remote_code=True
).to(DEVICE)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
def cached_vision_process(image, max_crops, num_tokens):
    """Return projected image features for *image*, using an on-disk cache.

    Features are produced by ``model.vision_model`` followed by
    ``model.multi_modal_projector`` and stored as ``.pt`` files under
    ``visual_cache/``, keyed by the image content and the crop/token settings,
    so repeated requests skip the vision encoder.

    Args:
        image: PIL image to encode.
        max_crops: Maximum number of crops passed to ``processor.image_processor``.
        num_tokens: Number of image tokens requested from ``model.vision_model``.

    Returns:
        The projected image features, moved to ``DEVICE`` with dtype ``DTYPE``.
    """
    import tempfile

    # Raw pixel bytes alone are ambiguous: two images with different sizes or
    # modes can share the same byte string. Mix mode and size into the key.
    hasher = hashlib.sha256()
    hasher.update(f"{image.mode}:{image.size[0]}x{image.size[1]}:".encode("utf-8"))
    hasher.update(image.tobytes())
    image_hash = hasher.hexdigest()

    cache_dir = "visual_cache"
    cache_path = os.path.join(cache_dir, f"{image_hash}-{max_crops}-{num_tokens}.pt")
    if os.path.exists(cache_path):
        # Load onto CPU first so a cache written on a GPU host still loads
        # where CUDA is unavailable; then move to the active device/dtype.
        return torch.load(cache_path, map_location="cpu").to(DEVICE, dtype=DTYPE)

    processor_outputs = processor.image_processor([image], max_crops)
    pixel_values = [
        value.to(model.device, dtype=model.dtype)
        for value in processor_outputs["pixel_values"]
    ]
    coords = [
        value.to(model.device, dtype=model.dtype)
        for value in processor_outputs["coords"]
    ]
    # Cached features are only ever used for inference; no autograd graph needed.
    with torch.no_grad():
        image_outputs = model.vision_model(pixel_values, coords, num_tokens)
        image_features = model.multi_modal_projector(image_outputs)

    os.makedirs(cache_dir, exist_ok=True)
    # Write to a temp file in the same directory and rename it into place, so
    # a concurrent reader never observes a partially written cache entry.
    fd, tmp_path = tempfile.mkstemp(dir=cache_dir, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as handle:
            torch.save(image_features, handle)
        os.replace(tmp_path, cache_path)
    except BaseException:
        # Don't leave a half-written temp file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return image_features.to(DEVICE, dtype=DTYPE)
@spaces.GPU(duration=20)
def answer_question(image, question, max_crops, num_tokens, sample, temperature, top_k):
    """Stream the model's answer to *question* about *image*.

    Used as a Gradio event handler. Each yielded string is the full answer
    accumulated so far, so the output box is progressively overwritten.

    Args:
        image: PIL image from the ``gr.Image`` component, or None.
        question: User question text, or None.
        max_crops: Maximum number of image crops (slider value).
        num_tokens: Number of image tokens to extract (slider value).
        sample: Whether to sample instead of decoding greedily.
        temperature: Sampling temperature; values <= 0 fall back to greedy decoding.
        top_k: Top-k value passed to ``generate``.

    Yields:
        A validation message, or the progressively growing answer text.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread is
        re-raised here once streaming has stopped.
    """
    if question is None or not question.strip():
        yield "Please ask a question"
        return
    if image is None:
        yield "Please upload an image"
        return

    # Slider values may arrive as floats; normalize them before they reach the
    # processor and the generation kwargs.
    max_crops = int(max_crops)
    num_tokens = int(num_tokens)
    top_k = int(top_k)
    temperature = float(temperature)
    # transformers raises a ValueError when sampling with temperature 0 (the
    # slider's default), so treat that combination as greedy decoding instead.
    do_sample = bool(sample) and temperature > 0

    prompt = (
        "<|im_start|>user\n"
        "<image>\n"
        f"{question}<|im_end|>\n"
        "<|im_start|>assistant\n"
    )

    streamer = TextIteratorStreamer(processor.tokenizer, skip_special_tokens=True)
    with torch.inference_mode():
        inputs = processor(prompt, [image], model, max_crops=max_crops, num_tokens=num_tokens)
    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "image_features": inputs["image_features"],
        "streamer": streamer,
        "max_length": 1000,
        "use_cache": True,
        "eos_token_id": processor.tokenizer.eos_token_id,
        "pad_token_id": processor.tokenizer.eos_token_id,
        "temperature": temperature,
        "do_sample": do_sample,
        "top_k": top_k,
    }

    errors = []

    def _run_generate():
        """Run ``model.generate`` in the worker thread, recording any failure."""
        try:
            model.generate(**generation_kwargs)
        except BaseException as exc:
            errors.append(exc)
            # Signal end-of-stream; otherwise the consumer loop below would
            # block on the streamer forever.
            streamer.end()

    thread = Thread(target=_run_generate)
    thread.start()

    buffer = ""
    output_started = False
    for new_text in streamer:
        if not output_started:
            # Presumably the decoded prompt is streamed first; drop text until
            # the assistant-turn marker shows up. TODO confirm against the
            # model's generate implementation.
            if "<|im_start|>assistant" in new_text:
                output_started = True
            continue
        if new_text:
            buffer += new_text
            # Yield on every growth, so even a one-character answer is shown.
            yield buffer

    thread.join()
    if errors:
        raise errors[0]
# Gradio UI: question box + submit button, crop/token sliders, image input,
# answer area, and sampling controls, all wired to `answer_question`.
with gr.Blocks() as demo:
    gr.HTML("<h1 class='gradio-heading'><center>MC-LLaVA 3B</center></h1>")
    gr.HTML(
        "<center><p class='gradio-sub-heading'>MC-LLaVA 3B is a model that can answer questions about small details in high-resolution images. Check out the <a href='https://huggingface.co/visheratin/MC-LLaVA-3b'>model card</a> for more details. If you have any questions or ideas on how to make the model better, <a href='https://x.com/visheratin'>let me know</a>.</p></center>"
    )
    gr.HTML(
        "<center><p class='gradio-sub-heading'>There are two main parameters - max number of crops and number of image tokens. The first one controls how many parts the image will be cut into. This is especially useful when you are working with high-resolution images. The second parameter controls how many image features will be extracted for the LLM to process. You can increase it if you are trying to extract info from a small part of the image, e.g., text.</p></center>"
    )
    with gr.Group():
        with gr.Row():
            prompt = gr.Textbox(
                label="Question", placeholder="e.g. What is this?", scale=4
            )
            submit = gr.Button(
                "Submit",
                scale=1,
            )
        with gr.Row():
            max_crops = gr.Slider(minimum=0, maximum=200, step=5, value=0, label="Max crops")
            num_tokens = gr.Slider(minimum=728, maximum=2184, step=10, value=728, label="Number of image tokens")
    with gr.Row():
        img = gr.Image(type="pil", label="Upload or Drag an Image")
        output = gr.TextArea(label="Answer")
    with gr.Row():
        sample = gr.Checkbox(label="Sample", value=False)
        temperature = gr.Slider(minimum=0, maximum=1, step=0.1, value=0, label="Temperature")
        top_k = gr.Slider(minimum=0, maximum=50, step=1, value=0, label="Top-K")

    # The button and pressing Enter in the question box trigger the same
    # handler with the same inputs; keep the list in one place.
    handler_inputs = [img, prompt, max_crops, num_tokens, sample, temperature, top_k]
    submit.click(answer_question, handler_inputs, output)
    prompt.submit(answer_question, handler_inputs, output)

demo.queue().launch(debug=True)