"""Gradio web demo that streams chat responses from a LLaMA-style generator.

Sets up a single-process model-parallel environment, loads the pretrained
7B model once at import time, and serves a simple chatbot UI.  The bot's
full reply is generated up front and then revealed word-by-word to
simulate streaming.
"""

import os
import time

import torch  # NOTE(review): not referenced directly here; presumably imported for side effects required by `gen` — confirm before removing
import gradio as gr

from strings import TITLE, ABSTRACT
from gen import get_pretrained_models, get_output, setup_model_parallel

# Single-process "distributed" setup: rank 0 of a world of 1, talking to
# itself on localhost.  These must be set before setup_model_parallel().
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"

local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Module-level conversation state.
# NOTE(review): these globals are shared across ALL browser sessions of this
# Gradio app — fine for a single-user demo, but concurrent users will see
# each other's history.  Consider gr.State for per-session history.
history = []          # full transcript as role/content dicts (written, never read here)
simple_history = []   # list of (user, bot) pairs in the format gr.Chatbot displays


def chat(user_input):
    """Generate a reply and yield it to the chatbot word-by-word.

    Args:
        user_input: the prompt text submitted from the textbox.

    Yields:
        The updated ``simple_history`` list after each word is appended,
        which Gradio renders as an incremental "streaming" update.
    """
    # The model produces the complete answer first; streaming below is simulated.
    bot_response = get_output(generator, user_input)[0]

    history.append({
        "role": "user",
        "content": user_input
    })
    history.append({
        "role": "system",  # NOTE(review): replies are usually role "assistant"; kept as-is since `history` is write-only here
        "content": bot_response
    })

    # Placeholder pair; the None reply is filled in incrementally below.
    simple_history.append((user_input, None))

    response = ""
    for word in bot_response.split(" "):
        time.sleep(0.1)  # pacing delay so the UI appears to stream
        response += word + " "
        current_pair = (user_input, response)
        simple_history[-1] = current_pair
        yield simple_history


with gr.Blocks(css = """#col_container {width: 95%; margin-left: auto; margin-right: auto;}
#chatbot {height: 400px; overflow: auto;}""") as demo:
    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")

        chatbot = gr.Chatbot(elem_id='chatbot')
        textbox = gr.Textbox(placeholder="Enter a prompt")

        # Submitting the textbox streams chat()'s yields into the chatbot.
        textbox.submit(chat, textbox, chatbot)

# Queueing is required for generator (streaming) handlers; API access disabled.
demo.queue(api_open=False).launch()