VL-Chatbox / app.py
vilarin's picture
Update app.py
d616ff6 verified
raw
history blame
No virus
3.87 kB
import torch
from PIL import Image
import gradio as gr
import spaces
from transformers import AutoModel, AutoTokenizer
import os
# Module configuration: environment flags, model selection and static page markup.

# Switch on the HF_HUB_ENABLE_HF_TRANSFER flag for Hugging Face Hub downloads.
# NOTE(review): the flag is set after `transformers` has already been imported
# above; confirm the hub library still honours it at this point.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"

# Checkpoints this app is written for; the first entry doubles as the default.
MODEL_LIST = ["openbmb/MiniCPM-Llama3-V-2_5","openbmb/MiniCPM-Llama3-V-2_5-int4"]
HF_TOKEN = os.environ.get("HF_TOKEN", None)

# Fall back to the first known checkpoint when MODEL_ID is unset or empty;
# otherwise the `.split` below would fail on None at import time.
MODEL_ID = os.environ.get("MODEL_ID") or MODEL_LIST[0]
# Repository name without the organisation prefix, used as the link label.
MODEL_NAME = MODEL_ID.split("/")[-1]

TITLE = "<h1><center>VL-Chatbox</center></h1>"
DESCRIPTION = f'<h3><center>MODEL: <a href="https://hf.co/{MODEL_ID}">{MODEL_NAME}</a></center></h3>'

# Styling for the "duplicate space" button (matched via elem_classes).
CSS = """
.duplicate-button {
margin: auto !important;
color: white !important;
background: black !important;
border-radius: 100vh !important;
}
"""
# Load the selected checkpoint in float16 (the repository's custom model code
# is allowed via trust_remote_code), then move it to device 0.
model = AutoModel.from_pretrained(MODEL_ID, torch_dtype=torch.float16, trust_remote_code=True)
model = model.to(0)

# The tokenizer comes from the same repository as the model weights.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

# The app only runs generation, so put the model in evaluation mode.
model.eval()
@spaces.GPU()
def stream_chat(message, history: list, temperature: float, max_new_tokens: int):
    """Stream the model's reply to one multimodal chat turn.

    Args:
        message: Payload of the multimodal textbox; read as a mapping with a
            ``"text"`` entry and a ``"files"`` list of image paths.
        history: Earlier turns as ``(prompt, answer)`` pairs.
        temperature: Sampling temperature; ``0`` turns sampling off.
        max_new_tokens: Generation length limit forwarded to ``model.chat``.

    Yields:
        The reply text accumulated so far, once per streamed chunk.

    Raises:
        gr.Error: If the message carries no image and there is no history
            to take one from.
    """
    print(f'message is - {message}')
    print(f'history is - {history}')

    conversation = []
    if message["files"]:
        # A new upload: use the last attached image and start the
        # conversation from this message only.
        image = Image.open(message["files"][-1]).convert('RGB')
    else:
        if not history:
            raise gr.Error("Please upload an image first.")
        # Assumes the first history entry's prompt is the tuple of uploaded
        # file paths -- TODO confirm against Gradio's multimodal history format.
        # Converted to RGB so it matches the freshly uploaded image path above.
        image = Image.open(history[0][0][0]).convert('RGB')
        for prompt, answer in history:
            # Turns without an answer are replayed with an empty reply.
            conversation.extend([
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": "" if answer is None else answer},
            ])
    conversation.append({"role": "user", "content": message['text']})
    print(f"Conversation is -\n{conversation}")

    generate_kwargs = dict(
        image=image,
        msgs=conversation,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        # A temperature of zero means deterministic decoding: disable sampling.
        sampling=temperature != 0,
        tokenizer=tokenizer,
        stream=True,
    )

    response = model.chat(**generate_kwargs)

    generated_text = ""
    for new_text in response:
        generated_text += new_text
        # Yield the running text so the chat UI can show the answer growing.
        yield generated_text
# Shared UI components, created up front so they can be handed to the chat
# interface and the examples widget.
chatbot = gr.Chatbot(height=450)

# Text box that also accepts image uploads.
chat_input = gr.MultimodalTextbox(
    show_label=False,
    placeholder="Enter message or upload file...",
    file_types=["image"],
    interactive=True,
)

# One example row per bundled image, each paired with the same prompt.
EXAMPLES = [
    [{"text": "Describe it in great detailed.", "files": [image_path]}]
    for image_path in ("./laptop.jpg", "./hotel.jpg", "./spacecat.png")
]
# Page layout: header markup, duplicate button, the chat interface with its
# collapsible generation parameters, and the example prompts.
with gr.Blocks(css=CSS) as demo:
    gr.HTML(TITLE)
    gr.HTML(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_classes="duplicate-button")

    # Built with render=False and passed to the chat interface below.
    _params_accordion = gr.Accordion(label="⚙️ Parameters", open=False, render=False)
    _temperature_slider = gr.Slider(
        label="Temperature",
        minimum=0,
        maximum=1,
        step=0.1,
        value=0.8,
        render=False,
    )
    _max_tokens_slider = gr.Slider(
        label="Max new tokens",
        minimum=128,
        maximum=4096,
        step=1,
        value=1024,
        render=False,
    )

    # The slider values are passed to stream_chat as extra positional inputs
    # (temperature, then max_new_tokens).
    gr.ChatInterface(
        fn=stream_chat,
        multimodal=True,
        textbox=chat_input,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=_params_accordion,
        additional_inputs=[_temperature_slider, _max_tokens_slider],
    )

    gr.Examples(EXAMPLES, [chat_input])
if __name__ == "__main__":
    # Enable the request queue with the API closed, then start the server
    # without the API page and without a public share link.
    queued_demo = demo.queue(api_open=False)
    queued_demo.launch(show_api=False, share=False)