yentinglin committed on
Commit
ca877b2
·
1 Parent(s): 628629b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -99
app.py CHANGED
@@ -1,120 +1,72 @@
 
 
1
  import os
2
-
3
  import gradio as gr
4
  from text_generation import Client
5
- from conversation import get_default_conv_template, SeparatorStyle
6
 
7
 
 
 
8
  eos_token = "</s>"
9
 
10
- def _concat_messages(messages):
11
- message_text = ""
12
- for message in messages:
13
- if message["role"] == "system":
14
- message_text += "<|system|>\n" + message["content"].strip() + "\n"
15
- elif message["role"] == "user":
16
- message_text += "<|user|>\n" + message["content"].strip() + "\n"
17
- elif message["role"] == "assistant":
18
- message_text += "<|assistant|>\n" + message["content"].strip() + eos_token + "\n"
19
- else:
20
- raise ValueError("Invalid role: {}".format(message["role"]))
21
- return message_text
22
 
23
- endpoint_url = os.environ.get("ENDPOINT_URL")
24
- client = Client(endpoint_url, timeout=120)
25
 
26
- def generate_response(user_input, max_new_token, top_p, top_k, temperature, do_sample, repetition_penalty):
27
- user_input = user_input.strip()
28
  conv = get_default_conv_template("vicuna").copy()
29
  roles = {"human": conv.roles[0], "gpt": conv.roles[1]} # map human to USER and gpt to ASSISTANT
30
- role = roles["human"]
31
- conv.append_message(role, user_input)
32
- conv.append_message(roles["gpt"], None)
33
  msg = conv.get_prompt()
34
 
35
- res = client.generate(
36
- msg,
37
- stop_sequences=["<|assistant|>", eos_token, "<|system|>", "<|user|>"],
38
- max_new_tokens=max_new_token,
39
- top_p=top_p,
40
- top_k=top_k,
41
- do_sample=do_sample,
42
- temperature=temperature,
43
- repetition_penalty=repetition_penalty,
44
- )
45
- return [("assistant", res.generated_text)]
46
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  with gr.Blocks() as demo:
48
  chatbot = gr.Chatbot()
49
- with gr.Row():
50
- with gr.Column(scale=4):
51
- with gr.Column(scale=12):
52
- user_input = gr.Textbox(
53
- show_label=False,
54
- placeholder="Shift + Enter傳送...",
55
- lines=10).style(
56
- container=False)
57
- with gr.Column(min_width=32, scale=1):
58
- submitBtn = gr.Button("Submit", variant="primary")
59
- with gr.Column(scale=1):
60
- emptyBtn = gr.Button("Clear History")
61
- max_new_token = gr.Slider(
62
- 1,
63
- 1024,
64
- value=128,
65
- step=1.0,
66
- label="Maximum New Token Length",
67
- interactive=True)
68
- top_p = gr.Slider(0, 1, value=0.9, step=0.01,
69
- label="Top P", interactive=True)
70
- temperature = gr.Slider(
71
- 0,
72
- 1,
73
- value=0.5,
74
- step=0.01,
75
- label="Temperature",
76
- interactive=True)
77
- top_k = gr.Slider(1, 40, value=40, step=1,
78
- label="Top K", interactive=True)
79
- do_sample = gr.Checkbox(
80
- value=True,
81
- label="Do Sample",
82
- info="use random sample strategy",
83
- interactive=True)
84
- repetition_penalty = gr.Slider(
85
- 1.0,
86
- 3.0,
87
- value=1.1,
88
- step=0.1,
89
- label="Repetition Penalty",
90
- interactive=True)
91
 
92
- params = [user_input, chatbot]
93
- predict_params = [
94
- chatbot,
95
- max_new_token,
96
- top_p,
97
- temperature,
98
- top_k,
99
- do_sample,
100
- repetition_penalty]
101
 
102
- submitBtn.click(
103
- generate_response,
104
- [user_input, max_new_token, top_p, top_k, temperature, do_sample, repetition_penalty],
105
- [chatbot],
106
- queue=False
107
- )
108
 
109
- user_input.submit(
110
- generate_response,
111
- [user_input, max_new_token, top_p, top_k, temperature, do_sample, repetition_penalty],
112
- [chatbot],
113
- queue=False
114
- )
 
 
115
 
116
- submitBtn.click(lambda: None, [], [user_input])
117
-
118
- emptyBtn.click(lambda: chatbot.reset(), outputs=[chatbot], show_progress=True)
 
 
 
 
119
 
120
- demo.launch()
 
1
+ import random
2
+ import time
3
  import os
 
4
  import gradio as gr
5
  from text_generation import Client
6
+ from conversation import get_default_conv_template
7
 
8
 
9
+ endpoint_url = os.environ.get("ENDPOINT_URL")
10
+ client = Client(endpoint_url, timeout=120)
11
  eos_token = "</s>"
12
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
 
 
14
 
15
def generate_response(history, max_new_token=512, top_p=0.9, temperature=0.8, do_sample=True):
    """Stream the assistant's reply for the latest turn of a chat history.

    This is a generator meant to be used as a Gradio event handler whose
    output is the ``Chatbot`` component, so every yield is the *whole*
    history (not a single token) with the last turn's reply grown by the
    newly received text.

    Args:
        history: List of ``[user_message, bot_message]`` pairs. The last
            pair is the pending turn; its ``bot_message`` is overwritten.
        max_new_token: Forwarded to the client as ``max_new_tokens``.
        top_p: Forwarded to the client as ``top_p``.
        temperature: Forwarded to the client as ``temperature``.
        do_sample: Forwarded to the client as ``do_sample``.

    Yields:
        The same ``history`` list, updated in place after each non-special
        token received from ``client.generate_stream``.
    """
    if not history:
        # No pending turn to answer (empty or None history): yield nothing
        # so the chatbot is left untouched instead of failing on history[-1].
        return

    conv = get_default_conv_template("vicuna").copy()
    roles = {"human": conv.roles[0], "gpt": conv.roles[1]}  # map human to USER and gpt to ASSISTANT
    for user, bot in history:
        conv.append_message(roles["human"], user)
        # For the pending turn `bot` is still None.
        # NOTE(review): presumably the template treats a None message as the
        # open assistant slot to be generated - confirm in conversation.py.
        conv.append_message(roles["gpt"], bot)
    msg = conv.get_prompt()

    # Start the reply empty and grow it token by token; yielding the full
    # history is what lets the Chatbot output re-render the partial answer.
    history[-1][1] = ""
    for response in client.generate_stream(
        msg,
        max_new_tokens=max_new_token,
        top_p=top_p,
        temperature=temperature,
        do_sample=do_sample,
    ):
        # Special tokens are skipped so they never reach the displayed text.
        if not response.token.special:
            history[-1][1] += response.token.text
            yield history
 
 
32
 
33
+ # res = client.generate(
34
+ # msg,
35
+ # stop_sequences=["<|assistant|>", eos_token, "<|system|>", "<|user|>"],
36
+ # max_new_tokens=max_new_token,
37
+ # top_p=top_p,
38
+ # top_k=top_k,
39
+ # do_sample=do_sample,
40
+ # temperature=temperature,
41
+ # repetition_penalty=repetition_penalty,
42
+ # )
43
+ # return [("assistant", res.generated_text)]
44
+ #
45
  with gr.Blocks() as demo:
46
  chatbot = gr.Chatbot()
47
+ msg = gr.Textbox()
48
+ clear = gr.Button("Clear")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49
 
50
def user(user_message, history):
    """Append the submitted message to the chat history as a pending turn.

    Args:
        user_message: Text taken from the input textbox.
        history: List of ``[user_message, bot_message]`` pairs, or ``None``.

    Returns:
        A ``(textbox_value, history)`` tuple: an empty string (which clears
        the textbox) and a new history list ending with
        ``[user_message, None]``. The input list is not mutated.
    """
    # The Clear button writes None into the chatbot (`lambda: None`), so a
    # None history is treated as empty instead of raising TypeError on `+`.
    previous_turns = history or []
    return "", previous_turns + [[user_message, None]]
 
 
 
 
 
 
 
52
 
53
def bot(history):
    """Demo responder that "types" a random canned reply into the last turn.

    Args:
        history: List of ``[user_message, bot_message]`` pairs; the last
            pair's ``bot_message`` is overwritten in place.

    Yields:
        The same ``history`` list after each appended character, so the UI
        can render the reply progressively.
    """
    # history = list of [[user_message, bot_message], ...]
    # The leftover interactive debugger breakpoint was removed: it would
    # block the served app on every call waiting for console input.
    bot_message = random.choice(["How are you?", "I love you", "I'm very hungry"])
    history[-1][1] = ""
    for character in bot_message:
        history[-1][1] += character
        time.sleep(0.05)  # small delay per character to simulate typing
        yield history
64
 
65
+ msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
66
+ generate_response, chatbot, chatbot
67
+ )
68
+ clear.click(lambda: None, None, chatbot, queue=False)
69
+
70
+ demo.queue()
71
+ demo.launch()
72