File size: 1,455 Bytes
f745223
504b6c8
e8fb838
1ffd977
 
13a089e
e8fb838
 
f745223
cbcb343
f745223
 
 
d8a82cd
52c453e
f745223
13a089e
a9db698
13a089e
1ffd977
3988351
13a089e
 
 
 
 
 
 
 
 
a9db698
 
13a089e
1ffd977
 
a9db698
f57923a
a9db698
 
 
1ffd977
cf7aa4d
9b0bdb7
a9db698
9b0bdb7
a9db698
9b0bdb7
 
13a089e
9b0bdb7
1ffd977
2334dc1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os
import time
import torch
import gradio as gr

from strings import TITLE, ABSTRACT 
from gen import get_pretrained_models, get_output, setup_model_parallel

# This demo runs the model in a single process: rank 0 of a world of size 1.
# These two are forced (not setdefault) so that a stray RANK/WORLD_SIZE in the
# environment cannot switch the app into a multi-process configuration it
# was never written for.
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
# Rendezvous address/port, presumably consumed by torch.distributed's env://
# initialisation inside gen.setup_model_parallel -- TODO confirm. setdefault
# lets a value exported by the caller win (e.g. to move off port 50505 when
# it is already taken) while keeping the original values as defaults.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "50505")

# Build the generator once, at import time, via the gen helpers; chat()
# reuses this single instance for every request.
local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

# Conversation logs appended to by chat():
#   history        -- list of {"role": ..., "content": ...} dicts
#   simple_history -- list of (user_text, reply_text) pairs shown in the Chatbot
# NOTE(review): these are module globals, so every session served by this
# process shares the same conversation.
history = []
simple_history = []

def chat(user_input):
    """Generate a reply to *user_input* and stream it into the chatbot.

    The whole reply is produced up front by ``get_output``. Streaming is only
    cosmetic: the growing reply is re-yielded one word at a time, with a 0.1 s
    pause before each word, so the Gradio ``Chatbot`` appears to type it out.

    Side effects:
        Appends a user turn and an assistant turn to the module-level
        ``history`` list, and one ``(user_input, reply)`` pair to the
        module-level ``simple_history`` list.

    Args:
        user_input: The prompt text submitted from the textbox.

    Yields:
        The module-level ``simple_history`` list. Its last pair holds the part
        of the reply revealed so far; the final yield holds the full reply.
    """
    # NOTE(review): ``history`` and ``simple_history`` are module globals, so
    # every browser session served by this process shares (and sees) one
    # conversation. Per-session state (e.g. ``gr.State``, or taking the
    # Chatbot's current value as an input) would isolate users.
    bot_response = get_output(generator, user_input)[0]

    history.append({"role": "user", "content": user_input})
    # Model replies are "assistant" turns in the usual chat-message schema;
    # "system" is conventionally reserved for framing instructions.
    history.append({"role": "assistant", "content": bot_response})

    simple_history.append((user_input, None))

    words = bot_response.split(" ")
    for count in range(1, len(words) + 1):
        time.sleep(0.1)
        # Re-join the revealed prefix with the same separator used to split,
        # so no trailing space is appended and the final frame is exactly
        # ``bot_response``.
        simple_history[-1] = (user_input, " ".join(words[:count]))
        yield simple_history

# Styling for the page: a centred column and a fixed-height, scrollable chat.
CHAT_CSS = """#col_container {width: 95%; margin-left: auto; margin-right: auto;}
                #chatbot {height: 400px; overflow: auto;}"""

# Page layout: one column holding the title/abstract, the chat transcript,
# and the prompt box.
with gr.Blocks(css=CHAT_CSS) as demo:
    with gr.Column(elem_id="col_container"):
        header_md = f"## {TITLE}\n\n\n\n{ABSTRACT}"
        gr.Markdown(header_md)
        chatbot = gr.Chatbot(elem_id="chatbot")
        textbox = gr.Textbox(placeholder="Enter a prompt")
        # Submitting the textbox runs chat() on its text; every value chat()
        # yields replaces the Chatbot's contents.
        textbox.submit(fn=chat, inputs=textbox, outputs=chatbot)

# Enable the request queue (Gradio 3.x needs it for generator handlers such
# as chat()), with direct REST API access disabled, then start the server.
demo.queue(api_open=False)
demo.launch()