import os
import time
import torch
import gradio as gr

from strings import TITLE, ABSTRACT, EXAMPLES
from gen import get_pretrained_models, get_output, setup_model_parallel

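# Single-process torch.distributed settings so setup_model_parallel()
# can initialize without an external launcher.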
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "50505"

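# Initialize model parallelism and load the 7B checkpoint and tokenizer
# once at startup; every request reuses this generator.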
local_rank, world_size = setup_model_parallel()
generator = get_pretrained_models("7B", "tokenizer", local_rank, world_size)

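# In-memory conversation log; it grows for the lifetime of the process.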
history = []

def chat(
    user_input,
    include_input,
    truncate,
    top_p,
    temperature,
    max_gen_len,
    state_chatbot
):
    """Generate a reply to user_input and stream it into the chatbot state."""
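    # Sample one completion for the prompt; get_output returns a list.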
    bot_response = get_output(
        generator=generator, 
        prompt=user_input,
        max_gen_len=max_gen_len,
        temperature=temperature,
        top_p=top_p)[0]

    # Strip the echoed prompt from the start of the response unless the
    # user asked to include it, then convert newlines for HTML rendering.
    if not include_input:
        bot_response = bot_response[len(user_input):]
    bot_response = bot_response.replace("\n", "<br>")
    
    # Trim the unfinished last sentence by cutting at the final period.
    # str.rfind never raises; it returns -1 when no period is found, so
    # guard explicitly instead of slicing (which would empty the string).
    if truncate:
        last_period = bot_response.rfind(".")
        if last_period != -1:
            bot_response = bot_response[:last_period + 1]

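    # Record both turns in the module-level conversation log.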
    history.append({
        "role": "user",
        "content": user_input
    })
    history.append({
        "role": "system",
        "content": bot_response
    })    

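    # Show the user message immediately with an empty bot slot.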
    state_chatbot = state_chatbot + [(user_input, None)]
    
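    # Stream the reply word by word so the chat UI shows a typing effect.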
    response = ""
    for word in bot_response.split(" "):
        time.sleep(0.1)
        response += word + " "
        current_pair = (user_input, response)
        state_chatbot[-1] = current_pair
        yield state_chatbot, state_chatbot

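# Clears the input textbox after each submission.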
def reset_textbox():
    return gr.update(value='')

with gr.Blocks(css = """#col_container {width: 95%; margin-left: auto; margin-right: auto;}
                #chatbot {height: 400px; overflow: auto;}""") as demo:

    state_chatbot = gr.State([])
                    
    with gr.Column(elem_id='col_container'):
        gr.Markdown(f"## {TITLE}\n\n\n\n{ABSTRACT}")

        with gr.Accordion("Example prompts", open=False):
            example_str = "\n"
            for example in EXAMPLES:
                example_str += f"- {example}\n"
            
            gr.Markdown(example_str)        
        
        chatbot = gr.Chatbot(elem_id='chatbot')
        textbox = gr.Textbox(placeholder="Enter a prompt")

        with gr.Accordion("Parameters", open=False):
            include_input = gr.Checkbox(value=True, label="Include the prompt in the generated text?")
            truncate = gr.Checkbox(value=True, label="Truncate the unfinished last sentence?")

            max_gen_len = gr.Slider(minimum=20, maximum=512, value=256, step=1, interactive=True, label="Max Generation Length")
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=1.0, step=0.05, interactive=True, label="Top-p (nucleus sampling)")
            temperature = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Temperature")
        
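    # On submit: stream the reply into the chatbot, then clear the textbox.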
    textbox.submit(
        chat, 
        [textbox, include_input, truncate, top_p, temperature, max_gen_len, state_chatbot],
        [state_chatbot, chatbot]
    )
    textbox.submit(reset_textbox, [], [textbox])

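# Queueing is required for generator (streaming) handlers; api_open=False
# disables direct HTTP access to the queue endpoints.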
demo.queue(api_open=False).launch()