|
import gradio as gr |
|
from uuid import uuid4 |
|
from threading import Thread |
|
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer |
|
|
|
from theme_dropdown import create_theme_dropdown |
|
|
|
|
|
# --- Model / generation settings ---------------------------------------------
model_name = "RootYuan/RootYuan-RedLing-7B-v0.1"
max_new_tokens = 2048  # passed to model.generate() as the per-reply token cap
device = "cpu"

# --- Prompt construction -----------------------------------------------------
# Default system prompt; bot() prepends the system message verbatim to the
# rendered conversation.
DEFAULT_SYSTEM_MESSAGE = """
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
"""

VISION_TOKEN = "<img>"
# A block of 32 image placeholder tokens framed by newlines.
VISION_TOKENS = f"\n{VISION_TOKEN * 32}\n"

EOT_TOKEN = "<EOT>"

# One conversation turn. Format fields: user, assistant, eos_token.
# NOTE(review): there is no {media} field here, so the `media=` argument that
# convert_history_to_text passes to .format() is silently ignored — confirm
# whether vision tokens are meant to be inserted.
PROMPT_TEMPLATE = "USER:{user}" + EOT_TOKEN + "ASSISTANT:{assistant}{eos_token}"
|
|
|
# Load the tokenizer and weights once at import time; every request handler in
# this module shares these module-level objects.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)
model = model.to(device)

# Theme-picker component plus a JS snippet; the UI wires the snippet to the
# dropdown's change event (presumably to apply the chosen theme client-side).
dropdown, js = create_theme_dropdown()
|
|
|
def get_uuid():
    """Return a freshly generated random UUID (version 4) as a 36-character string."""
    return f"{uuid4()}"
|
|
|
|
|
def add_text(message, history):
    """Queue ``message`` as a new user turn with an empty reply slot.

    Returns ``("", new_history)``: the empty string clears the input textbox,
    and ``new_history`` is a new list (``history`` itself is not mutated)
    ending with ``[message, ""]`` — the slot ``bot`` later streams into.
    """
    pending_turn = [message, ""]
    cleared_box = ""
    return cleared_box, history + [pending_turn]
|
|
|
|
|
def add_media(media, history):
    """Append an uploaded file to the chat history as a media entry.

    The entry is ``[(path, media_type), ""]``; ``convert_history_to_text``
    recognises media turns by their tuple first element.

    Args:
        media: the uploaded file object; only its ``.name`` (the file path)
            is used.
        history: current chat history (a list of ``[user, assistant]`` pairs).

    Returns:
        A new history list with the media entry appended; ``history`` itself
        is not mutated.
    """
    import mimetypes  # local import keeps this block self-contained

    media_name = media.name
    # The original only labelled .jpg/.jpeg/.png and left media_type unbound
    # (UnboundLocalError) for anything else — including the video/audio files
    # the upload button explicitly accepts, and upper-case names like "a.JPG".
    # guess_type() matches extensions case-insensitively.
    mime_type, _ = mimetypes.guess_type(media_name)
    major_type = mime_type.split("/", 1)[0] if mime_type else None
    media_type = major_type if major_type in ("image", "video", "audio") else "file"
    return history + [[(media_name, media_type), ""]]
|
|
|
|
|
def convert_history_to_text(history):
    """Render the Gradio chat history as a single prompt string.

    Every completed ``[user, assistant]`` pair is rendered with
    ``PROMPT_TEMPLATE`` and closed with ``EOT_TOKEN``. The final entry (the
    turn being answered) is rendered with an empty ``eos_token`` so the model
    continues the assistant's reply. Entries whose first element is a tuple
    are media uploads (as produced by ``add_media``). They add no text of
    their own but mark the next text turn as carrying media
    (``media=VISION_TOKENS``).

    NOTE(review): the ``PROMPT_TEMPLATE`` defined in this module has no
    ``{media}`` field, so ``str.format`` discards the ``media`` argument and
    no vision tokens currently reach the prompt — confirm the intended format.

    Args:
        history: list of ``[user, assistant]`` pairs; the last entry is the
            open turn. ``bot`` only runs after ``add_text``, so the last entry
            is expected to be a text turn.

    Returns:
        The concatenated prompt text.

    Raises:
        IndexError: if ``history`` is empty.
    """
    conversations = []
    pending_media = False

    for item in history[:-1]:
        if isinstance(item[0], tuple):
            # Media upload: tag the next text turn instead of emitting text.
            pending_media = True
            continue
        conversations.append(
            PROMPT_TEMPLATE.format(
                media=VISION_TOKENS if pending_media else "",
                user=item[0],
                assistant=item[1],
                eos_token=EOT_TOKEN,
            )
        )
        pending_media = False

    last = history[-1]
    # Previously the open turn always got media='', dropping an upload that
    # directly precedes the question (the usual "upload, then ask" flow).
    conversations.append(
        PROMPT_TEMPLATE.format(
            media=VISION_TOKENS if pending_media else "",
            user=last[0],
            assistant=last[1],
            eos_token="",
        )
    )
    return "".join(conversations)
|
|
|
|
|
def bot(history, temperature, top_k, sys_msg):
    """Stream the model's reply to the last turn of ``history``.

    The prompt is ``sys_msg`` followed by ``convert_history_to_text(history)``.
    ``model.generate`` runs on a background thread, and its output is consumed
    through a ``TextIteratorStreamer``. After each new chunk of text the
    accumulated reply is written to ``history[-1][1]`` and ``history`` is
    yielded.

    Args:
        history: Gradio chat history; the last entry is the turn to answer.
        temperature: sampling temperature from the UI slider. ``0`` selects
            greedy decoding; any positive value enables sampling.
        top_k: top-k cutoff from the UI slider; ``0`` disables top-k filtering.
        sys_msg: system prompt, prepended verbatim.

    Yields:
        The (mutated) ``history`` list after every streamed chunk.

    Raises:
        Exception: whatever ``model.generate`` raised on the worker thread is
            re-raised here, instead of surfacing as a streamer timeout.
    """
    print(f"history: {history}")

    prompt = sys_msg + convert_history_to_text(history)
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
    )
    # Without an explicit do_sample, generate() falls back to the model's
    # generation config. If that config is greedy, the temperature/top-k
    # sliders are silently ignored. If it samples, temperature=0 (allowed by
    # the slider) is rejected. So treat 0 as greedy and sample otherwise.
    if temperature > 0:
        generation_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            # Slider values may arrive as floats, but transformers' top-k
            # warper needs an int; 0 disables top-k filtering.
            top_k=int(top_k),
        )
    else:
        generation_kwargs["do_sample"] = False

    errors = []

    def _generate():
        """Run generation on the worker thread, capturing any exception."""
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:  # re-raised on the consumer side below
            errors.append(exc)
            # Unblock the consumer loop instead of letting it time out.
            streamer.end()

    # NOTE(review): cancelling this generator (e.g. via the Stop button) only
    # stops the consumer loop; the worker thread keeps generating until it
    # finishes on its own. A StoppingCriteria tied to a threading.Event
    # would let it stop early.
    thread = Thread(target=_generate)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        history[-1][1] = generated_text
        yield history

    if errors:
        raise errors[0]
|
|
|
|
|
with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    # Header row: title on the left, theme controls on the right.
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=12):
            gr.Markdown(
                """
                # Visual Assistant Lab
                """
            )
        with gr.Column(scale=2):
            with gr.Box():
                dropdown.render()
                toggle_dark = gr.Button(value="Toggle Dark").style(full_width=True)
            # Both controls run purely client-side (fn=None / no-op + _js).
            dropdown.change(None, dropdown, None, _js=js)
            toggle_dark.click(lambda: None, None, None, _js="() => {document.body.classList.toggle('dark')}")

    with gr.Row():
        with gr.Accordion("System Message", open=False):
            sys_msg = gr.Textbox(
                value=DEFAULT_SYSTEM_MESSAGE,
                label="System Message",
                info="Instruct the AI Assistant to set its behaviour",
                show_label=False,
            )

    with gr.Row():
        chatbot = gr.Chatbot(label="Assistant").style(height=500)

    with gr.Row():
        with gr.Accordion("Advanced Settings:", open=False):
            with gr.Row().style(equal_height=True):
                with gr.Column():
                    temperature = gr.Slider(
                        label="Temperature",
                        value=0.1,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        interactive=True,
                        info="Higher values produce more diverse outputs",
                    )
                with gr.Column():
                    top_k = gr.Slider(
                        label="Top-k",
                        value=0,
                        minimum=0.0,
                        maximum=200,
                        step=1,
                        interactive=True,
                        # The source had a mis-encoded "β" where this em dash belongs.
                        info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                    )

    with gr.Row().style(equal_height=True):
        with gr.Column(scale=12):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Hi! Type here, Press [Enter] to send...",
                show_label=False,
            ).style(container=False)
        with gr.Column(scale=2):
            send = gr.Button("Send")

    with gr.Row().style(equal_height=True):
        media = gr.UploadButton("Upload files", file_types=["image", "video", "audio"])
        stop = gr.Button("Stop")
        clear = gr.Button("Clear")

    # Enter in the textbox: append the user turn immediately (unqueued), then
    # stream the reply through the queue.
    send_event = msg.submit(
        fn=add_text,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, temperature, top_k, sys_msg],
        outputs=chatbot,
        queue=True,
    )

    media.upload(
        fn=add_media,
        inputs=[media, chatbot],
        outputs=[chatbot],
    )

    # The Send button mirrors the Enter-key flow above.
    send_click_event = send.click(
        fn=add_text,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, temperature, top_k, sys_msg],
        outputs=chatbot,
        queue=True,
    )

    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[send_event, send_click_event],
        queue=False,
    )
    # Clear also cancels any in-flight reply. Otherwise the still-running bot()
    # stream keeps yielding its own copy of the history and repopulates the
    # chatbot right after it was cleared.
    clear.click(
        lambda: None,
        None,
        chatbot,
        queue=False,
        cancels=[send_event, send_click_event],
    )
|
|
|
|
|
if __name__ == "__main__":
    # Enable the request queue (bounded at 128 pending requests, 2 concurrent
    # workers), then start the server.
    demo.queue(max_size=128, concurrency_count=2).launch()