VL-Chatbox / app.py
vilarin's picture
Update app.py
09399fd verified
raw
history blame
3.72 kB
from threading import Thread
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoProcessor,TextIteratorStreamer
import os
import time
from huggingface_hub import hf_hub_download
# NOTE(review): huggingface_hub probably reads HF_HUB_ENABLE_HF_TRANSFER when it
# is first imported, and it has already been imported above (directly and via
# transformers), so this assignment most likely has no effect on downloads. To
# really enable hf_transfer, set the variable in the Space settings or before the
# imports (the ``hf_transfer`` package must then be installed) — TODO confirm
# against the installed huggingface_hub version.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"
# Read here but not referenced anywhere else in this module.
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Repo id of the model to serve (passed to from_pretrained), e.g. "org/name".
MODEL_ID = os.environ.get("MODEL_ID")
if not MODEL_ID:
    # Fail fast with an actionable message instead of an AttributeError from
    # calling .split() on None further down.
    raise RuntimeError(
        "The MODEL_ID environment variable must be set to a Hugging Face "
        "model repo id, e.g. 'org/model-name'."
    )
# Repo name without the owner prefix, shown in the page header.
MODEL_NAME = MODEL_ID.split("/")[-1]

TITLE = "<h1><center>VL-Chatbox</center></h1>"
DESCRIPTION = f"<h3><center>MODEL: {MODEL_NAME}</center></h3>"
# System prompt used when the "System" box in the UI is left empty.
DEFAULT_SYSTEM = "You are named Chatbox. You are a good assistant."
# Styles the "Duplicate Space" button (elem_classes="duplicate-button").
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Load the model and its processor once at import time so every request reuses
# them. trust_remote_code lets the repo's own modelling/processing code run;
# the weights are loaded in float16 and then moved to CUDA device 0.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    low_cpu_mem_usage=True,
    torch_dtype=torch.float16,
)
model = model.to(0)
processor = AutoProcessor.from_pretrained(MODEL_ID, trust_remote_code=True)
# Tokenizer's end-of-sequence id, forwarded to generate() as eos_token_id.
eos_token_id = processor.tokenizer.eos_token_id
def _build_conversation(message, history, system):
    """Convert Gradio chat state into a chat-template conversation.

    Args:
        message: Current multimodal textbox value, a dict with ``"text"`` and
            ``"files"`` keys.
        history: Previous turns as ``(user, assistant)`` pairs (Gradio's
            tuples format).
        system: System prompt; ``DEFAULT_SYSTEM`` is used when it is empty.

    Returns:
        ``(conversation, image_path)``. ``conversation`` is a list of
        ``{"role": ..., "content": ...}`` dicts. When the chat contains an
        image, ``image_path`` is the most recently uploaded one, and exactly
        one user message (the turn it was uploaded with) is prefixed with the
        ``<|image_1|>`` tag. Otherwise ``image_path`` is None and no tag is
        added.
    """
    conversation = [{"role": "system", "content": system or DEFAULT_SYSTEM}]
    image_path = None
    image_turn = None  # index of the user message that carries the image tag
    image_pending = False  # an upload was seen but not yet its text turn

    for user_turn, answer in history:
        if isinstance(user_turn, (tuple, list)):
            # Uploaded files appear as their own history entries whose user
            # part is a tuple such as (path,) — Gradio 4.x multimodal
            # behaviour; confirm against the installed Gradio version.
            if user_turn:
                image_path = user_turn[0]
                image_pending = True
            continue
        # A text turn may be None (file sent with empty text), and the chat
        # template needs strings, so None is normalised to "".
        conversation.append({"role": "user", "content": user_turn or ""})
        if image_pending:
            image_turn = len(conversation) - 1
            image_pending = False
        conversation.append({"role": "assistant", "content": answer or ""})

    conversation.append({"role": "user", "content": message.get("text") or ""})
    files = message.get("files") or []
    if files:
        # A new upload replaces any earlier image; only the first file is used.
        image_path = files[0]
        image_turn = len(conversation) - 1
    elif image_pending:
        image_turn = len(conversation) - 1

    if image_path is not None:
        # Only one image is sent to the processor, so only one tag is emitted.
        tagged = conversation[image_turn]
        tagged["content"] = "<|image_1|>\n" + tagged["content"]
    return conversation, image_path


@spaces.GPU(duration=120, queue=False)
def stream_chat(message, history: list, system: str, temperature: float, max_new_tokens: int):
    """Stream the model's reply to ``message`` given the prior chat ``history``.

    Args:
        message: Current multimodal textbox value (``"text"`` and ``"files"``).
        history: Previous ``(user, assistant)`` turns supplied by Gradio.
        system: System prompt from the parameters panel ("" -> DEFAULT_SYSTEM).
        temperature: Sampling temperature; 0 selects greedy decoding.
        max_new_tokens: Upper bound on the number of generated tokens.

    Yields:
        The reply accumulated so far, growing as new text is decoded.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread,
        re-raised after the stream has been closed.
    """
    print(message)
    conversation, image_path = _build_conversation(message, history, system)

    images = None
    if image_path is not None:
        # convert() returns an independent copy, so the file can be closed here.
        with Image.open(image_path) as img:
            images = [img.convert("RGB")]

    prompt = processor.tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )
    # images=None (not [None]) for text-only prompts: the processor is presumably
    # Phi-3-vision-style and rejects None image entries — confirm for MODEL_ID.
    inputs = processor(prompt, images, return_tensors="pt").to(0)

    streamer = TextIteratorStreamer(
        processor.tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=False,
    )
    generate_kwargs = {
        **inputs,
        # Without the streamer, generate() never feeds the loop below.
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "eos_token_id": eos_token_id,
    }
    if temperature > 0:
        generate_kwargs.update(do_sample=True, temperature=temperature)
    else:
        # Greedy decoding; temperature is unused, so it is not passed.
        generate_kwargs["do_sample"] = False

    errors = []

    def _generate():
        # Runs generate() off the request thread so text can be yielded as it
        # arrives. On failure, close the streamer so the loop below ends
        # instead of blocking forever, and keep the error to re-raise.
        try:
            model.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    Thread(target=_generate).start()

    buffer = ""
    for new_text in streamer:
        buffer += new_text
        yield buffer
    if errors:
        raise errors[0]
# Chat display and multimodal (text + image) input handed to ChatInterface.
chatbot = gr.Chatbot(height=450)
chat_input = gr.MultimodalTextbox(
    interactive=True,
    file_types=["image"],
    placeholder="Enter message or upload file...",
    show_label=False,
)

with gr.Blocks(css=CSS) as demo:
    # Page header: title, model name, and the duplicate-Space button.
    for header_html in (TITLE, DESCRIPTION):
        gr.HTML(header_html)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Extra stream_chat arguments (system, temperature, max_new_tokens), in
    # that order. They are created unrendered so ChatInterface can place them
    # inside the collapsed accordion.
    parameters_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    parameter_inputs = [
        gr.Text(value="", label="System", render=False),
        gr.Slider(minimum=0, maximum=1, step=0.1, value=0.8, label="Temperature", render=False),
        gr.Slider(minimum=128, maximum=4096, step=1, value=1024, label="Max new tokens", render=False),
    ]

    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        chatbot=chatbot,
        textbox=chat_input,
        fill_height=True,
        additional_inputs_accordion=parameters_accordion,
        additional_inputs=parameter_inputs,
    )
if __name__ == "__main__":
    # Enable the request queue with api_open=False, then launch without a
    # public share link and with the API page hidden.
    app = demo.queue(api_open=False)
    app.launch(share=False, show_api=False)