manu committed on
Commit
808d645
1 Parent(s): 24e2c73

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +294 -0
app.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

import gradio as gr
from text_generation import Client

# HF-hosted endpoint for testing purposes (requires an HF API token)
API_TOKEN = os.environ.get("API_TOKEN", None)

# Streaming client for the remote text-generation-inference endpoint.
# NOTE(review): the base URL already ends in /generate_stream — confirm the
# text_generation Client expects the full route here rather than the endpoint
# root (the client normally appends the route itself).
CURRENT_CLIENT = Client("https://afrts4trc759c6eq.us-east-1.aws.endpoints.huggingface.cloud/generate_stream",
                        timeout=120,
                        headers={
                            "Accept": "application/json",
                            "Authorization": f"Bearer {API_TOKEN}",
                            "Content-Type": "application/json"}
                        )

# Prompt defaults, each overridable through an environment variable.
DEFAULT_HEADER = os.environ.get("HEADER", "")
DEFAULT_USER_NAME = os.environ.get("USER_NAME", "user")
DEFAULT_ASSISTANT_NAME = os.environ.get("ASSISTANT_NAME", "assistant")
DEFAULT_SEPARATOR = os.environ.get("SEPARATOR", "<|im_end|>")
# ChatML-style template for one (user query, assistant response) exchange.
PROMPT_TEMPLATE = "<|im_start|>{user_name}\n{query}{separator}\n<|im_start|>{assistant_name}\n{response}"
repo = None  # never read in this file — looks like a leftover; candidate for removal
23
+
24
+
25
def get_total_inputs(inputs, chatbot, preprompt, user_name, assistant_name, sep):
    """Assemble the full prompt: preprompt, past turns, then the new input.

    Each past (user, assistant) pair from ``chatbot`` is normalised so the
    user text starts with ``user_name`` and the model text starts with
    ``sep + assistant_name`` before being concatenated.
    """
    turns = []
    for user_data, model_data in chatbot:
        if not user_data.startswith(user_name):
            user_data = user_name + user_data
        if not model_data.startswith(sep + assistant_name):
            model_data = sep + assistant_name + model_data
        turns.append(user_data + model_data.rstrip() + sep)

    # Prefix the new message with the user name if it is missing.
    prefixed_inputs = inputs if inputs.startswith(user_name) else user_name + inputs

    # Trailing assistant name cues the model to answer next.
    return preprompt + "".join(turns) + prefixed_inputs + sep + assistant_name.rstrip()
43
+
44
+
45
def has_no_history(chatbot, history):
    """Return True iff both the chat display and the state history are empty."""
    if chatbot:
        return False
    return not history
47
+
48
+
49
def generate(
    user_message,
    chatbot,
    history,
    temperature,
    top_p,
    max_new_tokens,
    repetition_penalty,
    header,
    user_name,
    assistant_name,
    separator
):
    """Stream a model reply to ``user_message`` and update the chat state.

    Builds a ChatML-style prompt from ``header``, the past turns shown in
    ``chatbot`` and the new ``user_message``, then streams tokens from the
    remote endpoint, yielding ``(chat_pairs, history, last_user_message, "")``
    after every received token so the UI updates incrementally.

    ``repetition_penalty`` is accepted for interface compatibility but is
    currently not forwarded to the endpoint (see ``generate_kwargs`` below).
    """
    # Don't produce a meaningless response when the input is empty:
    # re-emit the current state unchanged and stop (the original fell through
    # and generated from an empty message despite the comment saying not to).
    if not user_message:
        print("Empty input")
        chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
        yield chat, history, user_message, ""
        return

    history.append(user_message)

    # Flatten the (user, assistant) pairs from the chatbot widget into
    # role-tagged messages.
    past_messages = []
    for user_data, model_data in chatbot:
        past_messages.extend(
            [{"role": "user", "content": user_data}, {"role": "assistant", "content": model_data.rstrip()}]
        )

    # Replay every past exchange through the prompt template, then append the
    # new user turn with an empty response slot for the model to complete.
    # (With no history the loop is a no-op, which matches the fresh-chat case.)
    prompt = header
    for i in range(0, len(past_messages), 2):
        intermediate_prompt = PROMPT_TEMPLATE.format(user_name=user_name,
                                                     query=past_messages[i]["content"],
                                                     assistant_name=assistant_name,
                                                     response=past_messages[i + 1]["content"],
                                                     separator=separator)
        prompt = prompt + intermediate_prompt + separator + "\n"

    prompt = prompt + PROMPT_TEMPLATE.format(user_name=user_name,
                                             query=user_message,
                                             assistant_name=assistant_name,
                                             response="",
                                             separator=separator)

    # Clamp temperature away from zero — sampling at exactly 0 is invalid.
    temperature = max(float(temperature), 1e-2)
    top_p = float(top_p)

    generate_kwargs = dict(
        temperature=temperature,
        max_new_tokens=max_new_tokens,
        top_p=top_p,
        top_k=40,
        # repetition_penalty=repetition_penalty,  # deliberately disabled
        do_sample=True,
        truncate=1024,
        # Stop on the separator configured in the UI (was hard-coded to
        # DEFAULT_SEPARATOR, which silently ignored a custom separator).
        stop_sequences=[separator],
    )

    stream = CURRENT_CLIENT.generate_stream(
        prompt,
        **generate_kwargs,
    )

    output = ""
    chat = []  # defined up-front so the final return is safe on an empty stream
    response_started = False
    for response in stream:
        # Skip special tokens (BOS/EOS/…) — they are not user-visible text.
        if response.token.special:
            continue
        output += response.token.text
        # Append the growing assistant message once, then keep overwriting it.
        # (A flag replaces the old enumerate-index check: if stream item 0 was
        # a special token, the append was skipped and history[-1] clobbered
        # the user's message instead.)
        if not response_started:
            history.append(" " + output)
            response_started = True
        else:
            history[-1] = output

        # Pair up (user, assistant) entries for the Chatbot widget.
        chat = [(history[i].strip(), history[i + 1].strip()) for i in range(0, len(history) - 1, 2)]
        yield chat, history, user_message, ""

    return chat, history, user_message, ""
147
+
148
+
149
def clear_chat():
    """Reset both the chatbot display and the raw message history."""
    empty_chat, empty_history = [], []
    return empty_chat, empty_history
151
+
152
+
153
# Page title (rendered via gr.HTML) and inline CSS injected into the Gradio
# page: centers the banner image and sizes the chat message area.
title = """<h1 align="center">CroissantLLMChat Playground 🥐</h1>"""
custom_css = """
#banner-image {
  display: block;
  margin-left: auto;
  margin-right: auto;
}
#chat-message {
  font-size: 14px;
  min-height: 300px;
}
"""
165
+
166
# --- Gradio UI -------------------------------------------------------------
# Layout: title, disclaimer, chat display, message box with Send/Clear
# buttons, and two accordions for sampling parameters and prompt settings.
with gr.Blocks(analytics_enabled=False, css=custom_css) as demo:
    gr.HTML(title)

    with gr.Row():
        with gr.Column():
            gr.Markdown(
                """
Demo platform for 🥐 CroissantLLMChat. Model is of small size and can hallucinate and generate incorrect or even toxic content.
"""
            )

    with gr.Row():
        # NOTE(review): gr.Box exists only in Gradio 3.x (removed in 4.x) —
        # this file appears pinned to Gradio 3.
        with gr.Box():
            output = gr.Markdown()
            chatbot = gr.Chatbot(elem_id="chat-message", label="Chat")

    with gr.Row():
        with gr.Column(scale=3):
            user_message = gr.Textbox(placeholder="Enter your message here", show_label=False, elem_id="q-input")
            with gr.Row():
                send_button = gr.Button("Send", elem_id="send-btn", visible=True)

                clear_chat_button = gr.Button("Clear chat", elem_id="clear-btn", visible=True)

            # Sampling parameters forwarded to generate().
            with gr.Accordion(label="Parameters", open=False, elem_id="parameters-accordion"):
                temperature = gr.Slider(
                    label="Temperature",
                    value=0.5,
                    minimum=0.1,
                    maximum=1.0,
                    step=0.1,
                    interactive=True,
                    info="Higher values produce more diverse outputs",
                )
                top_p = gr.Slider(
                    label="Top-p (nucleus sampling)",
                    value=0.9,
                    minimum=0.0,
                    maximum=1,
                    step=0.05,
                    interactive=True,
                    info="Higher values sample more low-probability tokens",
                )
                max_new_tokens = gr.Slider(
                    label="Max new tokens",
                    value=512,
                    minimum=0,
                    maximum=1024,
                    step=4,
                    interactive=True,
                    info="The maximum numbers of new tokens",
                )
                # NOTE(review): repetition_penalty is collected here but the
                # client call in generate() has it commented out.
                repetition_penalty = gr.Slider(
                    label="Repetition Penalty",
                    value=1.2,
                    minimum=0.0,
                    maximum=10,
                    step=0.1,
                    interactive=True,
                    info="The parameter for repetition penalty. 1.0 means no penalty.",
                )
            # Prompt-template settings (header, speaker names, separator).
            with gr.Accordion(label="Prompt", open=False, elem_id="prompt-accordion"):
                header = gr.Textbox(
                    label="Header instructions",
                    value=DEFAULT_HEADER,
                    interactive=True,
                    info="Instructions given to the assistant at the beginning of the prompt",
                )
                user_name = gr.Textbox(
                    label="User name",
                    value=DEFAULT_USER_NAME,
                    interactive=True,
                    info="Name to be given to the user in the prompt",
                )
                assistant_name = gr.Textbox(
                    label="Assistant name",
                    value=DEFAULT_ASSISTANT_NAME,
                    interactive=True,
                    info="Name to be given to the assistant in the prompt",
                )
                separator = gr.Textbox(
                    label="Separator",
                    value=DEFAULT_SEPARATOR,
                    interactive=True,
                    info="Character to be used when the speaker changes in the prompt",
                )

    # Per-session state: flat message history and the last submitted message.
    history = gr.State([])
    last_user_message = gr.State("")

    # Pressing Enter in the textbox streams a response via generate().
    user_message.submit(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            max_new_tokens,
            repetition_penalty,
            header,
            user_name,
            assistant_name,
            separator
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    # The Send button triggers the same handler with the same wiring.
    send_button.click(
        generate,
        inputs=[
            user_message,
            chatbot,
            history,
            temperature,
            top_p,
            max_new_tokens,
            repetition_penalty,
            header,
            user_name,
            assistant_name,
            separator
        ],
        outputs=[chatbot, history, last_user_message, user_message],
    )

    clear_chat_button.click(clear_chat, outputs=[chatbot, history])

# queue() enables streaming/generator handlers.
# NOTE(review): concurrency_count is a Gradio 3.x queue argument (removed in
# 4.x) — confirm against the pinned gradio version.
demo.queue(concurrency_count=16).launch(server_port=8001)