import torch
import gradio as gr
from uuid import uuid4
from threading import Thread
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from theme_dropdown import create_theme_dropdown
# Hugging Face Hub checkpoint that backs the assistant.
model_name = "RootYuan/RedLing-7B-v0.1"
# Upper bound on the number of tokens generated for a single reply.
max_new_tokens = 2048
# Use the GPU when torch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Initial content of the "System Message" textbox.
DEFAULT_SYSTEM_MESSAGE = """
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
"""

# Placeholder token for image input.
# NOTE(review): this is an empty string here, so VISION_TOKENS collapses to
# two newlines -- confirm the intended token text.
VISION_TOKEN = ""
VISION_TOKENS = f"\n{VISION_TOKEN * 32}\n"
# Marker appended after every completed assistant reply.
# NOTE(review): also empty here -- confirm the intended end-of-turn token.
EOT_TOKEN = ""
# Layout of one conversation turn. The template has no "{media}" field, so a
# ``media=`` keyword passed to ``.format`` is silently ignored.
PROMPT_TEMPLATE = "USER:{user}ASSISTANT:{assistant}{eos_token}"

# Load the tokenizer and the bfloat16 weights once at import time, then move
# the model to the selected device.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
model = model.to(device)

# Theme picker component plus the JavaScript hook that applies the selection.
dropdown, js = create_theme_dropdown()
def get_uuid() -> str:
    """Return a newly generated random (version 4) UUID as a string."""
    fresh_id = uuid4()
    return str(fresh_id)
def add_text(message, history):
    """Queue the user's message as a new turn and clear the input box.

    Args:
        message: Text the user just submitted.
        history: Conversation so far, as a list of ``[user, assistant]`` pairs.

    Returns:
        A tuple ``(textbox_value, new_history)``: an empty string for the
        textbox, and a new list equal to ``history`` plus one turn whose
        assistant slot is still empty. ``history`` itself is not modified.
    """
    pending_turn = [message, ""]
    cleared_box = ""
    return cleared_box, history + [pending_turn]
def add_media(media, history):
    """Append an uploaded image to the conversation history.

    Args:
        media: Uploaded file object; only its ``name`` attribute (the file
            path) is read.
        history: Conversation so far, as a list of ``[user, assistant]`` pairs.

    Returns:
        A new history list with one extra turn whose user slot is the tuple
        ``(path, "image")`` and whose assistant slot is empty. If the file
        extension is not jpg, jpeg or png, ``history`` is returned unchanged.
    """
    media_name = media.name
    # Compare the extension case-insensitively so "PHOTO.JPG" is accepted too.
    media_format = media_name.rsplit(".", 1)[-1].lower()
    if media_format not in ("jpg", "jpeg", "png"):
        # Unsupported upload: leave the conversation untouched rather than
        # failing on a media type that was never assigned.
        return history
    return history + [[(media_name, "image"), ""]]
def convert_history_to_text(history):
    """Render the conversation history as a single prompt string.

    Each entry of ``history`` is a ``[user, assistant]`` pair. An entry whose
    user slot is a tuple is treated as an uploaded-media marker: it produces
    no text of its own, but makes the next text turn receive
    ``media=VISION_TOKENS`` when the template is formatted. This now also
    applies when that next turn is the final one, which previously always
    received ``media=''``.

    Completed turns are terminated with ``EOT_TOKEN``; the final turn is left
    open (empty ``eos_token``) so the model continues the assistant reply.

    NOTE(review): whether ``media`` has any visible effect depends on
    ``PROMPT_TEMPLATE`` containing a ``{media}`` field -- ``str.format``
    ignores unused keyword arguments.

    Args:
        history: List of ``[user, assistant]`` pairs.

    Returns:
        The concatenated prompt text, or ``""`` for an empty history.
    """
    if not history:
        return ""

    rendered = []
    image_pending = False
    for item in history[:-1]:
        if isinstance(item[0], tuple):
            # Media marker: remember it for the next text turn.
            image_pending = True
            continue
        rendered.append(
            PROMPT_TEMPLATE.format(
                media=VISION_TOKENS if image_pending else "",
                user=item[0],
                assistant=item[1],
                eos_token=EOT_TOKEN,
            )
        )
        image_pending = False

    # The final turn is always rendered and stays open for the model's reply.
    current = history[-1]
    rendered.append(
        PROMPT_TEMPLATE.format(
            media=VISION_TOKENS if image_pending else "",
            user=current[0],
            assistant=current[1],
            eos_token="",
        )
    )
    return "".join(rendered)
def bot(history, temperature, top_k, sys_msg):
    """Stream the assistant's reply to the most recent user turn.

    Generation runs on a background thread; decoded text is pulled from a
    ``TextIteratorStreamer`` and written into the assistant slot of the last
    history entry as it arrives.

    Args:
        history: List of ``[user, assistant]`` pairs; the last entry is the
            turn being answered and is updated in place.
        temperature: Sampling temperature. A value of 0 (or below) selects
            greedy decoding.
        top_k: Size of the top-k shortlist used when sampling; 0 disables the
            top-k filter.
        sys_msg: System message placed in front of the rendered conversation.

    Yields:
        ``history`` after each streamed chunk, with the partial reply so far.
    """
    print(f"history: {history}")
    # Construct the input message string for the model by concatenating the current system message and conversation history
    messages = sys_msg + convert_history_to_text(history)
    input_ids = tokenizer(messages, return_tensors="pt").input_ids.to(device)
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # temperature/top_k only take effect when sampling is enabled, and a
    # temperature of 0 is not a valid sampling temperature, so sample only
    # for positive temperatures and decode greedily otherwise.
    do_sample = temperature > 0
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        do_sample=do_sample,
        streamer=streamer,
    )
    if do_sample:
        # The slider may deliver top_k as a float; pass an int to generate.
        generation_kwargs.update(temperature=float(temperature), top_k=int(top_k))

    # Run generation in the background so tokens can be yielded as they arrive.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Accumulate the streamed chunks into the reply shown in the chat window.
    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        history[-1][1] = generated_text
        yield history
# Build the chat UI and wire its events. Every component and listener is
# created inside the Blocks context so it is registered on ``demo``.
with gr.Blocks(theme="sudeepshouche/minimalist") as demo:
    # Header: page title on the left, theme controls on the right.
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=12):
            gr.Markdown("\n# Visual Assistant Lab\n")
        with gr.Column(scale=2):
            with gr.Box():
                dropdown.render()
                toggle_dark = gr.Button(value="Toggle Dark").style(full_width=True)

    # Theme switching and the dark-mode toggle are driven by the attached
    # JavaScript snippets.
    dropdown.change(None, dropdown, None, _js=js)
    toggle_dark.click(
        lambda: None,
        None,
        None,
        _js="() => {document.body.classList.toggle('dark')}",
    )
    # conversation_id = gr.State(get_uuid)

    # Collapsible editor for the system message placed in front of each prompt.
    with gr.Row():
        with gr.Accordion("System Message", open=False):
            sys_msg = gr.Textbox(
                value=DEFAULT_SYSTEM_MESSAGE,
                label="System Message",
                info="Instruct the AI Assistant to set its behaviour",
                show_label=False,
            )

    # Conversation window.
    with gr.Row():
        chatbot = gr.Chatbot(label="Assistant").style(height=500)

    # Collapsible generation settings.
    with gr.Row():
        with gr.Accordion("Advanced Settings:", open=False):
            with gr.Row().style(equal_height=True):
                with gr.Column():
                    temperature = gr.Slider(
                        label="Temperature",
                        value=0.1,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        interactive=True,
                        info="Higher values produce more diverse outputs",
                    )
                with gr.Column():
                    top_k = gr.Slider(
                        label="Top-k",
                        value=0,
                        minimum=0.0,
                        maximum=200,
                        step=1,
                        interactive=True,
                        info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                    )

    # Message entry box with its Send button.
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=12):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Hi! Type here, Press [Enter] to send...",
                show_label=False,
            ).style(container=False)
        with gr.Column(scale=2):
            send = gr.Button("Send")

    # Upload / Stop / Clear controls.
    with gr.Row().style(equal_height=True):
        # Only images are offered: add_media recognises jpg/jpeg/png uploads
        # and has no handling for video or audio files.
        media = gr.UploadButton("Upload files", file_types=["image"])
        stop = gr.Button("Stop")
        clear = gr.Button("Clear")

    # Pressing Enter and clicking Send run the same two-step chain: record
    # the user's turn immediately (unqueued), then stream the reply through
    # the queue.
    user_turn_kwargs = dict(
        fn=add_text,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    )
    bot_turn_kwargs = dict(
        fn=bot,
        inputs=[chatbot, temperature, top_k, sys_msg],
        outputs=chatbot,
        queue=True,
    )

    send_event = msg.submit(**user_turn_kwargs).then(**bot_turn_kwargs)
    media.upload(
        fn=add_media,
        inputs=[media, chatbot],
        outputs=[chatbot],
    )
    send_click_event = send.click(**user_turn_kwargs).then(**bot_turn_kwargs)

    # Stop cancels whichever reply chain is currently running.
    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[send_event, send_click_event],
        queue=False,
    )
    # Clear empties the chat window.
    clear.click(lambda: None, None, chatbot, queue=False)
if __name__ == "__main__":
    # Enable request queuing (needed for the streamed, cancellable replies)
    # before starting the web server.
    queue_options = {"max_size": 128, "concurrency_count": 2}
    demo.queue(**queue_options)
    demo.launch()