# qwen-vl / app.py
# multimodalart's picture
# UI refinements
# 81d45ed
# raw
# history blame
# 6.19 kB
import copy
import html
import re
import secrets
from io import BytesIO
from pathlib import Path

import gradio as gr
import requests
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoTokenizer
# Tokenizer and model are both loaded from the same "Qwen/Qwen-VL-Chat-Int4"
# checkpoint at import time. trust_remote_code=True is required for checkpoints
# that ship their own Python code (see the transformers documentation); the
# chat helpers used further down (model.chat, tokenizer.draw_bbox_on_latest_picture)
# are presumably provided that way -- confirm against the model repository.
tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen-VL-Chat-Int4",
    trust_remote_code=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen-VL-Chat-Int4",
    device_map="auto",
    trust_remote_code=True,
).eval()
# Non-greedy pattern matching one <box>...</box> span (newlines included);
# predict() uses it to strip such spans from the text shown next to an image.
BOX_TAG_PATTERN = r"<box>([\s\S]*?)</box>"

# Trailing characters that add_text() strips from the text sent to the model.
#
# The previous literal was closed by an unescaped double quote right after its
# third character, so only "!?。" was in effect and the rest of the intended
# set was parsed as a `#` comment.
# NOTE(review): the full-width forms below are reconstructed from that
# commented-out remainder (its duplicated "「」" / "、" and the "⦅⦆" pair look
# like full-/half-width forms that were collapsed) -- confirm against the
# upstream Qwen-VL demo. ASCII "!" and "?" are kept so everything the
# truncated literal matched is still matched.
PUNCTUATION = (
    "!?"
    "！？。＂＃＄％＆＇（）＊＋，－／：；＜＝＞＠［＼］＾＿｀｛｜｝～"
    "｟｠｢｣､、〃》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿"
    "–—‘’‛“”„‟…‧﹏."
)
def _parse_text(text):
lines = text.split("\n")
lines = [line for line in lines if line != ""]
count = 0
for i, line in enumerate(lines):
if "```" in line:
count += 1
items = line.split("`")
if count % 2 == 1:
lines[i] = f'<pre><code class="language-{items[-1]}">'
else:
lines[i] = f"<br></code></pre>"
else:
if i > 0:
if count % 2 == 1:
line = line.replace("`", r"\`")
line = line.replace("<", "&lt;")
line = line.replace(">", "&gt;")
line = line.replace(" ", "&nbsp;")
line = line.replace("*", "&ast;")
line = line.replace("_", "&lowbar;")
line = line.replace("-", "&#45;")
line = line.replace(".", "&#46;")
line = line.replace("!", "&#33;")
line = line.replace("(", "&#40;")
line = line.replace(")", "&#41;")
line = line.replace("$", "&#36;")
lines[i] = "<br>" + line
text = "".join(lines)
return text
def predict(_chatbot, task_history):
    """Run the model on the latest turn and write the reply into both histories.

    Args:
        _chatbot: List of (user, bot) pairs shown in the chat widget; its last
            entry is the turn being answered. Mutated in place.
        task_history: List of (query, answer) pairs; a query is either text or
            a tuple/list whose first item is an image path. The last entry is
            updated in place with the new answer.

    Returns:
        The updated ``_chatbot`` list.
    """
    displayed_query = _chatbot[-1][0]
    raw_query = task_history[-1][0]

    # Fold the history into (prompt, answer) turns: image entries are
    # accumulated as "Picture N: <img>path</img>" lines and prepended to the
    # next text query; only text entries produce a turn.
    turns = []
    picture_no = 1
    pending = ""
    for question, answer in copy.deepcopy(task_history):
        if not isinstance(question, (tuple, list)):
            turns.append((pending + question, answer))
            pending = ""
            continue
        pending += f'Picture {picture_no}: <img>{question[0]}</img>' + '\n'
        picture_no += 1

    # The last turn is the message to answer; everything before it is context.
    message = turns[-1][0]
    response, history = model.chat(tokenizer, message, history=turns[:-1])
    # Returns None when there is nothing to draw for this response.
    image = tokenizer.draw_bbox_on_latest_picture(response, history)

    if image is None:
        _chatbot[-1] = (_parse_text(displayed_query), response)
    else:
        # Save the annotated image under a randomly named directory and show
        # it as the bot reply for this turn.
        out_dir = Path("/tmp") / secrets.token_hex(20)
        out_dir.mkdir(exist_ok=True, parents=True)
        image_path = out_dir / f"tmp{secrets.token_hex(5)}.jpg"
        image.save(str(image_path))
        _chatbot[-1] = (_parse_text(displayed_query), (str(image_path),))

        # Whatever text remains after removing the <ref>/<box> markup is
        # appended as a separate bot-only message.
        caption = response.replace("<ref>", "").replace(r"</ref>", "")
        caption = re.sub(BOX_TAG_PATTERN, "", caption)
        if caption != "":
            _chatbot.append((None, caption))

    # Note: the stored answer is the _parse_text() output, not the raw response.
    task_history[-1] = (raw_query, _parse_text(response))
    return _chatbot
def add_text(history, task_history, text):
    """Append a new user text turn to the display and task histories.

    When ``text`` has at least two characters, ends with a character from
    PUNCTUATION and the character before it is not in PUNCTUATION, that final
    character is dropped from the copy stored in ``task_history``. The chat
    display always gets the full text, passed through ``_parse_text``.

    Args:
        history: Current list of chat-widget pairs (not mutated).
        task_history: Current list of (query, answer) pairs (not mutated).
        text: The user's input.

    Returns:
        A 3-tuple: the new display history, the new task history, and an
        empty string.
    """
    ends_with_single_mark = (
        len(text) >= 2
        and text[-1] in PUNCTUATION
        and text[-2] not in PUNCTUATION
    )
    task_text = text[:-1] if ends_with_single_mark else text
    return (
        history + [(_parse_text(text), None)],
        task_history + [(task_text, None)],
        "",
    )
def add_file(history, task_history, file):
    """Append an uploaded file as a new, unanswered turn in both histories.

    Args:
        history: Current list of chat-widget pairs (not mutated).
        task_history: Current list of (query, answer) pairs (not mutated).
        file: Object exposing the uploaded file's path as ``file.name``.

    Returns:
        A 2-tuple of new lists, each ending with ``((file.name,), None)``.
    """
    upload_entry = ((file.name,), None)
    return history + [upload_entry], task_history + [upload_entry]
def reset_user_input():
    """Return a Gradio update that sets the bound component's value to ""."""
    cleared = gr.update(value="")
    return cleared
def reset_state(task_history):
    """Empty the task history in place and return a fresh empty chat list.

    Args:
        task_history: The list of (query, answer) pairs; cleared in place so
            every holder of this list sees it empty.

    Returns:
        A new empty list, used as the chat widget's value.
    """
    task_history.clear()
    empty_chat = []
    return empty_chat
def regenerate(_chatbot, task_history):
    """Discard the last answer and ask ``predict`` to produce it again.

    Args:
        _chatbot: List of chat-widget pairs; mutated in place.
        task_history: List of (query, answer) pairs; mutated in place.

    Returns:
        ``_chatbot`` unchanged when there is no history or the last turn has
        no answer yet; otherwise whatever ``predict`` returns.
    """
    print("Regenerate clicked")
    print("Before:", task_history, _chatbot)

    # Nothing to redo without a turn, or while the last turn is unanswered.
    if not task_history or task_history[-1][1] is None:
        return _chatbot

    last_query = task_history[-1][0]
    task_history[-1] = (last_query, None)

    removed = _chatbot.pop(-1)
    if removed[0] is not None:
        # The removed pair carried the user message itself: put it back
        # without a reply.
        _chatbot.append((removed[0], None))
    else:
        # The removed pair was a bot-only message; also blank the reply of
        # the pair before it.
        _chatbot[-1] = (_chatbot[-1][0], None)

    print("After:", task_history, _chatbot)
    return predict(_chatbot, task_history)
# Page-level stylesheet: caps the width of the app container.
css = '''
.gradio-container{max-width:800px}
'''

# The stylesheet is passed by keyword. The first positional parameter of
# gr.Blocks is the theme, so passing the string positionally meant it was
# taken as a theme name and the CSS was never applied.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Qwen-VL-Chat Bot")
    # NOTE(review): this string has an unclosed "**" and a "</center>" with no
    # opening tag; left byte-identical here -- confirm the intended markup.
    gr.Markdown("## Qwen-VL: A Multimodal Large Vision Language Model by Alibaba Cloud **Space by [@Artificialguybr](https://twitter.com/artificialguybr)</center>")
    chatbot = gr.Chatbot(label='Qwen-VL-Chat', elem_classes="control-height", height=520)
    query = gr.Textbox(lines=2, label='Input')
    # Per-session list of (query, answer) pairs passed to the handlers.
    task_history = gr.State([])

    with gr.Row():
        addfile_btn = gr.UploadButton("📁 Upload", file_types=["image"])
        submit_btn = gr.Button("🚀 Submit")
        regen_btn = gr.Button("🤔️ Regenerate")
        empty_bin = gr.Button("🧹 Clear History")

    gr.Markdown("### Key Features:\n- **Strong Performance**: Surpasses existing LVLMs on multiple English benchmarks including Zero-shot Captioning and VQA.\n- **Multi-lingual Support**: Supports English, Chinese, and multi-lingual conversation.\n- **High Resolution**: Utilizes 448*448 resolution for fine-grained recognition and understanding.")

    # add_text returns three values (chat, task history, cleared input), so it
    # is bound to three outputs; previously the third value had no target.
    submit_btn.click(add_text, [chatbot, task_history, query], [chatbot, task_history, query]).then(
        predict, [chatbot, task_history], [chatbot], show_progress=True
    )
    submit_btn.click(reset_user_input, [], [query])
    empty_bin.click(reset_state, [task_history], [chatbot], show_progress=True)
    regen_btn.click(regenerate, [chatbot, task_history], [chatbot], show_progress=True)
    addfile_btn.upload(add_file, [chatbot, task_history, addfile_btn], [chatbot, task_history], show_progress=True)

demo.launch()