"""Gradio chat demo for RootYuan/RedLing-7B-v0.1 with streamed generation.

The conversation lives in a ``gr.Chatbot``. On each user turn the whole
history is rendered into one prompt string, ``model.generate`` runs on a
background thread, and decoded text is streamed back into the last chat turn.
"""

from threading import Thread
from uuid import uuid4

import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from theme_dropdown import create_theme_dropdown

model_name = "RootYuan/RedLing-7B-v0.1"
max_new_tokens = 2048
device = 'cuda' if torch.cuda.is_available() else 'cpu'

DEFAULT_SYSTEM_MESSAGE = """
A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.
"""

# NOTE(review): VISION_TOKEN and EOT_TOKEN are empty strings here. They look
# like placeholders for model special tokens that may have been lost; confirm
# against the model's tokenizer before relying on image or end-of-turn markers.
VISION_TOKEN = ''
VISION_TOKENS = '\n' + VISION_TOKEN * 32 + '\n'
EOT_TOKEN = ""
# NOTE(review): this template has no ``{media}`` field, so the ``media=``
# argument passed to ``format`` below is ignored and vision tokens never reach
# the prompt. Add ``{media}`` where the model expects it to enable them.
PROMPT_TEMPLATE = "USER:{user}ASSISTANT:{assistant}{eos_token}"

# Lower-case file extensions (without the dot) that add_media accepts as images.
IMAGE_EXTENSIONS = ("jpg", "jpeg", "png")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)

dropdown, js = create_theme_dropdown()


def get_uuid():
    """Return a fresh random UUID4 as a string."""
    return str(uuid4())


def add_text(message, history):
    """Append the user's message as a new turn with an empty assistant reply.

    Returns ``("", new_history)`` so Gradio clears the textbox and updates
    the chatbot with the new turn.
    """
    return "", history + [[message, ""]]


def add_media(media, history):
    """Append an uploaded image to the chat history as a media turn.

    The entry's user slot is a ``(file_path, "image")`` tuple, which
    ``convert_history_to_text`` recognises as a media marker.

    Raises:
        gr.Error: if the upload is not a supported image format. The upload
            button also accepts video/audio, which nothing downstream handles;
            previously such uploads crashed with an UnboundLocalError.
    """
    media_name = media.name
    # Compare case-insensitively so ".JPG" / ".PNG" uploads are accepted too.
    media_format = media_name.rsplit(".", 1)[-1].lower()
    if media_format not in IMAGE_EXTENSIONS:
        raise gr.Error(f"Unsupported file type: .{media_format}")
    return history + [[(media_name, "image"), ""]]


def convert_history_to_text(history):
    """Render the chatbot history as a single prompt string.

    ``history`` is a list of ``[user, assistant]`` pairs. An entry whose user
    slot is a tuple is a media upload (see ``add_media``). It emits no text
    itself but marks the next text turn as carrying ``VISION_TOKENS``.
    Completed turns end with ``EOT_TOKEN``. The final, in-progress turn gets
    no end token so the model continues it.
    """
    if not history:
        return ""

    conversations = []
    pending_media = False
    for user, assistant in history[:-1]:
        if isinstance(user, tuple):
            pending_media = True
            continue
        conversations.append(
            PROMPT_TEMPLATE.format(
                media=VISION_TOKENS if pending_media else '',
                user=user,
                assistant=assistant,
                eos_token=EOT_TOKEN,
            )
        )
        pending_media = False

    # The current turn: an image uploaded just before it belongs to it too.
    last_user, last_assistant = history[-1]
    conversations.append(
        PROMPT_TEMPLATE.format(
            media=VISION_TOKENS if pending_media else '',
            user=last_user,
            assistant=last_assistant,
            eos_token='',
        )
    )
    return "".join(conversations)


def bot(history, temperature, top_k, sys_msg):
    """Stream the model's reply into the last turn of ``history``.

    Args:
        history: chatbot value; ``history[-1][1]`` is filled in as text arrives.
        temperature: sampling temperature from the UI; ``0`` means greedy decoding.
        top_k: top-k shortlist size from the UI; ``0`` disables top-k filtering.
        sys_msg: system message prepended to the rendered conversation.

    Yields:
        ``history`` after each streamed chunk.
    """
    print(f"history: {history}")
    # The prompt is the current system message followed by the rendered history.
    messages = sys_msg + convert_history_to_text(history)
    input_ids = tokenizer(messages, return_tensors="pt").input_ids.to(device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        input_ids=input_ids,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
    )
    if temperature > 0:
        # temperature/top_k only take effect when sampling is enabled. Gradio
        # sliders may deliver floats, but transformers requires an int top_k.
        generation_kwargs.update(
            do_sample=True,
            temperature=float(temperature),
            top_k=int(top_k),
        )
    else:
        # A temperature of 0 is not a valid sampling temperature; decode greedily.
        generation_kwargs["do_sample"] = False

    # generate() blocks until finished, so run it on a worker thread and
    # consume decoded text from the streamer on this one.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    generated_text = ""
    for new_text in streamer:
        generated_text += new_text
        history[-1][1] = generated_text
        yield history
    thread.join()


with gr.Blocks(theme='sudeepshouche/minimalist') as demo:
    with gr.Row().style(equal_height=True):
        with gr.Column(scale=12):
            gr.Markdown(
                """
                # Visual Assistant Lab
                """
            )
        with gr.Column(scale=2):
            with gr.Box():
                dropdown.render()
                toggle_dark = gr.Button(value="Toggle Dark").style(full_width=True)

    dropdown.change(None, dropdown, None, _js=js)
    toggle_dark.click(lambda: None, None, None, _js="() => {document.body.classList.toggle('dark')}")

    # conversation_id = gr.State(get_uuid)

    with gr.Row():
        with gr.Accordion("System Message", open=False):
            sys_msg = gr.Textbox(
                value=DEFAULT_SYSTEM_MESSAGE,
                label="System Message",
                info="Instruct the AI Assistant to set its behaviour",
                show_label=False,
            )

    with gr.Row():
        chatbot = gr.Chatbot(label="Assistant").style(height=500)

    with gr.Row():
        with gr.Accordion("Advanced Settings:", open=False):
            with gr.Row().style(equal_height=True):
                with gr.Column():
                    temperature = gr.Slider(
                        label="Temperature",
                        value=0.1,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.1,
                        interactive=True,
                        info="Higher values produce more diverse outputs",
                    )
                with gr.Column():
                    top_k = gr.Slider(
                        label="Top-k",
                        value=0,
                        minimum=0.0,
                        maximum=200,
                        step=1,
                        interactive=True,
                        info="Sample from a shortlist of top-k tokens — 0 to disable and sample from all tokens.",
                    )

    with gr.Row().style(equal_height=True):
        with gr.Column(scale=12):
            msg = gr.Textbox(
                label="Chat Message Box",
                placeholder="Hi! Type here, Press [Enter] to send...",
                show_label=False,
            ).style(container=False)
        with gr.Column(scale=2):
            send = gr.Button("Send")

    with gr.Row().style(equal_height=True):
        media = gr.UploadButton("Upload files", file_types=["image", "video", "audio"])
        stop = gr.Button("Stop")
        clear = gr.Button("Clear")

    # Enter and the Send button share the same two-step pipeline: append the
    # user turn without queueing, then stream the reply through the queue.
    send_event = msg.submit(
        fn=add_text,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, temperature, top_k, sys_msg],
        outputs=chatbot,
        queue=True,
    )

    media.upload(
        fn=add_media,
        inputs=[media, chatbot],
        outputs=[chatbot],
    )

    send_click_event = send.click(
        fn=add_text,
        inputs=[msg, chatbot],
        outputs=[msg, chatbot],
        queue=False,
    ).then(
        fn=bot,
        inputs=[chatbot, temperature, top_k, sys_msg],
        outputs=chatbot,
        queue=True,
    )

    stop.click(
        fn=None,
        inputs=None,
        outputs=None,
        cancels=[send_event, send_click_event],
        queue=False,
    )
    clear.click(lambda: None, None, chatbot, queue=False)


if __name__ == "__main__":
    demo.queue(max_size=128, concurrency_count=2)
    demo.launch()