kevinwang676 committed on
Commit 5e06908
1 Parent(s): 5e540eb

Update app.py

Files changed (1):
  1. app.py +391 -11
app.py CHANGED
@@ -99,17 +99,397 @@ def convert(model, src, tgt):
  write("out.wav", 24000, audio)
  out = "out.wav"
  return out
-
- model = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC",type="value", label="Model")
- audio1 = gr.inputs.Audio(label="Source Audio", type='filepath')
- audio2 = gr.inputs.Audio(label="Reference Audio", type='filepath')
- inputs = [model, audio1, audio2]
- outputs = gr.outputs.Audio(label="Output Audio", type='filepath')
-
- title = "FreeVC"
- description = "Gradio Demo for FreeVC: Towards High-Quality Text-Free One-Shot Voice Conversion. To use it, simply upload your audio, or click the example to load. Read more at the links below. Note: It seems that the WavLM checkpoint in HuggingFace is a little different from the one used to train FreeVC, which may degrade the performance a bit. In addition, speaker similarity can be largely affected if there are too much silence in the reference audio, so please <strong>trim</strong> it before submitting."
- article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2210.15418' target='_blank'>Paper</a> | <a href='https://github.com/OlaWod/FreeVC' target='_blank'>Github Repo</a></p>"
-
- examples=[["FreeVC", 'p225_001.wav', 'p226_002.wav'], ["FreeVC-s", 'p226_002.wav', 'p225_001.wav'], ["FreeVC (24kHz)", 'p225_001.wav', 'p226_002.wav']]
-
- gr.Interface(convert, inputs, outputs, title=title, description=description, article=article, examples=examples, enable_queue=True).launch()
+
+ # GLM2
+
+ language_dict = tts_order_voice
+
+ # fix timezone in Linux
+ os.environ["TZ"] = "Asia/Shanghai"
+ try:
+     time.tzset()  # type: ignore # pylint: disable=no-member
+ except Exception:
+     # Windows
+     logger.warning("Windows, cant run time.tzset()")
+
+ # model_name = "THUDM/chatglm2-6b"
+ model_name = "THUDM/chatglm2-6b-int4"
+
+ RETRY_FLAG = False
+
+ tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+
+ # model = AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda()
+
+ # 4/8 bit
+ # model = AutoModel.from_pretrained("THUDM/chatglm2-6b", trust_remote_code=True).quantize(4).cuda()
+
+ has_cuda = torch.cuda.is_available()
+
+ # has_cuda = False  # force cpu
+
+ if has_cuda:
+     model_glm = (
+         AutoModel.from_pretrained(model_name, trust_remote_code=True).cuda().half()
+     )  # 3.92G
+ else:
+     model_glm = AutoModel.from_pretrained(
+         model_name, trust_remote_code=True
+     ).float()  # .float() .half().float()
+
+ model_glm = model_glm.eval()
+
+ _ = """Override Chatbot.postprocess"""
+
+
+ def postprocess(self, y):
+     if y is None:
+         return []
+     for i, (message, response) in enumerate(y):
+         y[i] = (
+             None if message is None else mdtex2html.convert((message)),
+             None if response is None else mdtex2html.convert(response),
+         )
+     return y
+
+
+ gr.Chatbot.postprocess = postprocess
+
+
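The override above routes every chat message through mdtex2html, which converts Markdown plus TeX into HTML for display in the Chatbot widget. A minimal standalone sketch of that conversion (the sample string is made up):

    import mdtex2html

    # Convert a Markdown + TeX snippet to HTML, as the patched postprocess
    # does for each (message, response) pair. The sample input is arbitrary.
    html = mdtex2html.convert("**bold** and an inline equation $x^2 + 1$")
    print(html)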
+ def parse_text(text):
+     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
+     lines = text.split("\n")
+     lines = [line for line in lines if line != ""]
+     count = 0
+     for i, line in enumerate(lines):
+         if "```" in line:
+             count += 1
+             items = line.split("`")
+             if count % 2 == 1:
+                 lines[i] = f'<pre><code class="language-{items[-1]}">'
+             else:
+                 lines[i] = "<br></code></pre>"
+         else:
+             if i > 0:
+                 if count % 2 == 1:
+                     line = line.replace("`", r"\`")
+                     line = line.replace("<", "&lt;")
+                     line = line.replace(">", "&gt;")
+                     line = line.replace(" ", "&nbsp;")
+                     line = line.replace("*", "&ast;")
+                     line = line.replace("_", "&lowbar;")
+                     line = line.replace("-", "&#45;")
+                     line = line.replace(".", "&#46;")
+                     line = line.replace("!", "&#33;")
+                     line = line.replace("(", "&#40;")
+                     line = line.replace(")", "&#41;")
+                     line = line.replace("$", "&#36;")
+                 lines[i] = "<br>" + line
+     text = "".join(lines)
+     return text
+
+
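For a quick sense of what parse_text emits, a small sketch (assuming parse_text from this commit is in scope; the sample string is invented). Fence lines toggle a <pre><code> element, and every later line gets a <br> prefix, with special characters escaped inside the fence:

    sample = "Here is code:\n```python\nprint(1)\n```"
    # Odd-numbered fence lines open <pre><code class="language-...">,
    # even-numbered ones close it.
    print(parse_text(sample))
    # -> Here is code:<pre><code class="language-python"><br>print&#40;1&#41;<br></code></pre>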
+ def predict(
+     RETRY_FLAG, input, chatbot, max_length, top_p, temperature, history, past_key_values
+ ):
+     try:
+         chatbot.append((parse_text(input), ""))
+     except Exception as exc:
+         logger.error(exc)
+         logger.debug(f"{chatbot=}")
+         _ = """
+         if chatbot:
+             chatbot[-1] = (parse_text(input), str(exc))
+         yield chatbot, history, past_key_values
+         # """
+         yield chatbot, history, past_key_values
+
+     for response, history, past_key_values in model_glm.stream_chat(
+         tokenizer,
+         input,
+         history,
+         past_key_values=past_key_values,
+         return_past_key_values=True,
+         max_length=max_length,
+         top_p=top_p,
+         temperature=temperature,
+     ):
+         chatbot[-1] = (parse_text(input), parse_text(response))
+         # chatbot[-1][-1] = parse_text(response)
+
+         yield chatbot, history, past_key_values, parse_text(response)
+
+
+ def trans_api(input, max_length=4096, top_p=0.8, temperature=0.2):
+     if max_length < 10:
+         max_length = 4096
+     if top_p < 0.1 or top_p > 1:
+         top_p = 0.85
+     if temperature <= 0 or temperature > 1:
+         temperature = 0.01
+     try:
+         res, _ = model_glm.chat(
+             tokenizer,
+             input,
+             history=[],
+             past_key_values=None,
+             max_length=max_length,
+             top_p=top_p,
+             temperature=temperature,
+         )
+         # logger.debug(f"{res=} \n{_=}")
+     except Exception as exc:
+         logger.error(f"{exc=}")
+         res = str(exc)
+
+     return res
+
+
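trans_api is stateless: it clamps out-of-range sampling parameters and then calls model_glm.chat with an empty history, so each request stands alone. A hedged usage sketch, assuming the model and tokenizer loaded above:

    # Out-of-range values fall back to defaults: max_length < 10 -> 4096,
    # top_p outside [0.1, 1] -> 0.85, temperature outside (0, 1] -> 0.01.
    res = trans_api("Translate to Chinese: Good morning", max_length=4096, top_p=0.8, temperature=0.2)
    print(res)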
+ def reset_user_input():
+     return gr.update(value="")
+
+
+ def reset_state():
+     return [], [], None, ""
+
+
+ # Delete last turn
+ def delete_last_turn(chat, history):
+     if chat and history:
+         chat.pop(-1)
+         history.pop(-1)
+     return chat, history
+
+
+ # Regenerate response
+ def retry_last_answer(
+     user_input, chatbot, max_length, top_p, temperature, history, past_key_values
+ ):
+     if chatbot and history:
+         # Removing the previous conversation from chat
+         chatbot.pop(-1)
+         # Setting up a flag to capture a retry
+         RETRY_FLAG = True
+         # Getting last message from user
+         user_input = history[-1][0]
+         # Removing bot response from the history
+         history.pop(-1)
+
+     yield from predict(
+         RETRY_FLAG,  # type: ignore
+         user_input,
+         chatbot,
+         max_length,
+         top_p,
+         temperature,
+         history,
+         past_key_values,
+     )
+
+ # print
+
+ def print(text):
+     return text
+
+ # TTS
+
+ async def text_to_speech_edge(text, language_code):
+     voice = language_dict[language_code]
+     communicate = edge_tts.Communicate(text, voice)
+     with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
+         tmp_path = tmp_file.name
+
+     await communicate.save(tmp_path)
+
+     return tmp_path
+
+
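text_to_speech_edge is a coroutine; Gradio awaits async event handlers automatically, but calling it outside the app needs an event loop. A minimal sketch using edge_tts directly (the hard-coded voice id is a real Edge TTS voice, chosen here as an assumption rather than looked up in tts_order_voice):

    import asyncio
    import tempfile

    import edge_tts

    async def demo_tts() -> str:
        # Same pattern as text_to_speech_edge above, with an explicit voice
        # instead of a language_dict lookup.
        communicate = edge_tts.Communicate("Hello, world", "en-US-AriaNeural")
        with tempfile.NamedTemporaryFile(delete=False, suffix=".mp3") as tmp_file:
            tmp_path = tmp_file.name
        await communicate.save(tmp_path)  # write the synthesized MP3
        return tmp_path

    print(asyncio.run(demo_tts()))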
+ with gr.Blocks(title="ChatGLM2-6B-int4", theme=gr.themes.Soft(text_size="sm")) as demo:
+     # gr.HTML("""<h1 align="center">ChatGLM2-6B-int4</h1>""")
+     gr.HTML(
+         """<center><a href="https://huggingface.co/spaces/mikeee/chatglm2-6b-4bit?duplicate=true"><img src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a>To avoid the queue and for faster inference Duplicate this Space and upgrade to GPU</center>"""
+     )
+
+     with gr.Accordion("🎈 Info", open=False):
+         _ = f"""
+         ## {model_name}
+         Try to refresh the browser and try again when occasionally an error occurs.
+         With a GPU, a query takes from a few seconds to a few tens of seconds, dependent on the number of words/characters
+         the question and responses contain. The quality of the responses varies quite a bit it seems. Even the same
+         question with the same parameters, asked at different times, can result in quite different responses.
+         * Low temperature: responses will be more deterministic and focused; High temperature: responses more creative.
+         * Suggested temperatures -- translation: up to 0.3; chatting: > 0.4
+         * Top P controls dynamic vocabulary selection based on context.
+         For a table of example values for different scenarios, refer to [this](https://community.openai.com/t/cheat-sheet-mastering-temperature-and-top-p-in-chatgpt-api-a-few-tips-and-tricks-on-controlling-the-creativity-deterministic-output-of-prompt-responses/172683)
+         If the instance is not on a GPU (T4), it will be very slow. You can try to run the colab notebook [chatglm2-6b-4bit colab notebook](https://colab.research.google.com/drive/1WkF7kOjVCcBBatDHjaGkuJHnPdMWNtbW?usp=sharing) for a spin.
+         The T4 GPU is sponsored by a community GPU grant from Huggingface. Thanks a lot!
+         """
+         gr.Markdown(dedent(_))
+     chatbot = gr.Chatbot()
+     with gr.Row():
+         with gr.Column(scale=4):
+             with gr.Column(scale=12):
+                 user_input = gr.Textbox(
+                     label="请在此处和GLM2聊天 (按回车键即可发送)",
+                     placeholder="聊点什么吧",
+                 )
+                 RETRY_FLAG = gr.Checkbox(value=False, visible=False)
+             with gr.Column(min_width=32, scale=1):
+                 with gr.Row():
+                     submitBtn = gr.Button("开始和GLM2交流吧", variant="primary")
+                     deleteBtn = gr.Button("删除最新一轮对话", variant="secondary")
+                     retryBtn = gr.Button("重新生成最新一轮对话", variant="secondary")
+
+     with gr.Accordion("更多设置", open=False):
+         with gr.Row():
+             emptyBtn = gr.Button("清空所有聊天记录")
+             max_length = gr.Slider(
+                 0,
+                 32768,
+                 value=8192,
+                 step=1.0,
+                 label="Maximum length",
+                 interactive=True,
+             )
+             top_p = gr.Slider(
+                 0, 1, value=0.85, step=0.01, label="Top P", interactive=True
+             )
+             temperature = gr.Slider(
+                 0.01, 1, value=0.95, step=0.01, label="Temperature", interactive=True
+             )
+
+
+     with gr.Row():
+         test1 = gr.Textbox(label="GLM2的最新回答 (可编辑)", lines=3)
+         with gr.Column():
+             language = gr.Dropdown(choices=list(language_dict.keys()), value="普通话 (中国大陆)-Xiaoxiao-女", label="请选择文本对应的语言及您喜欢的说话人")
+             tts_btn = gr.Button("生成对应的音频吧", variant="primary")
+             output_audio = gr.Audio(type="filepath", label="为您生成的音频")
+
+     tts_btn.click(text_to_speech_edge, inputs=[test1, language], outputs=[output_audio])
+
+     with gr.Row():
+         model_choice = gr.Dropdown(choices=["FreeVC", "FreeVC-s", "FreeVC (24kHz)"], value="FreeVC (24kHz)", label="Model", visible=False)
+         audio1 = output_audio
+         audio2 = gr.Audio(label="请上传您喜欢的声音进行声音克隆", type='filepath')
+         clone_btn = gr.Button("开始AI声音克隆吧")
+         audio_cloned = gr.Audio(label="为您生成的专属声音克隆音频", type='filepath')
+
+     clone_btn.click(convert, inputs=[model_choice, audio1, audio2], outputs=[audio_cloned])
+
+     history = gr.State([])
+     past_key_values = gr.State(None)
+
+     user_input.submit(
+         predict,
+         [
+             RETRY_FLAG,
+             user_input,
+             chatbot,
+             max_length,
+             top_p,
+             temperature,
+             history,
+             past_key_values,
+         ],
+         [chatbot, history, past_key_values, test1],
+         show_progress="full",
+     )
+     submitBtn.click(
+         predict,
+         [
+             RETRY_FLAG,
+             user_input,
+             chatbot,
+             max_length,
+             top_p,
+             temperature,
+             history,
+             past_key_values,
+         ],
+         [chatbot, history, past_key_values, test1],
+         show_progress="full",
+         api_name="predict",
+     )
+     submitBtn.click(reset_user_input, [], [user_input])
+
+     emptyBtn.click(
+         reset_state, outputs=[chatbot, history, past_key_values, test1], show_progress="full"
+     )
+
+     retryBtn.click(
+         retry_last_answer,
+         inputs=[
+             user_input,
+             chatbot,
+             max_length,
+             top_p,
+             temperature,
+             history,
+             past_key_values,
+         ],
+         # outputs = [chatbot, history, last_user_message, user_message]
+         outputs=[chatbot, history, past_key_values, test1],
+     )
+     deleteBtn.click(delete_last_turn, [chatbot, history], [chatbot, history])
+
+     with gr.Accordion("Example inputs", open=False):
+         etext = """In America, where cars are an important part of the national psyche, a decade ago people had suddenly started to drive less, which had not happened since the oil shocks of the 1970s. """
+         examples = gr.Examples(
+             examples=[
+                 ["Explain the plot of Cinderella in a sentence."],
+                 [
+                     "How long does it take to become proficient in French, and what are the best methods for retaining information?"
+                 ],
+                 ["What are some common mistakes to avoid when writing code?"],
+                 ["Build a prompt to generate a beautiful portrait of a horse"],
+                 ["Suggest four metaphors to describe the benefits of AI"],
+                 ["Write a pop song about leaving home for the sandy beaches."],
+                 ["Write a summary demonstrating my ability to tame lions"],
+                 ["鲁迅和周树人什么关系"],
+                 ["从前有一头牛,这头牛后面有什么?"],
+                 ["正无穷大加一大于正无穷大吗?"],
+                 ["正无穷大加正无穷大大于正无穷大吗?"],
+                 ["-2的平方根等于什么"],
+                 ["树上有5只鸟,猎人开枪打死了一只。树上还有几只鸟?"],
+                 ["树上有11只鸟,猎人开枪打死了一只。树上还有几只鸟?提示:需考虑鸟可能受惊吓飞走。"],
+                 ["鲁迅和周树人什么关系 用英文回答"],
+                 ["以红楼梦的行文风格写一张委婉的请假条。不少于320字。"],
+                 [f"{etext} 翻成中文,列出3个版本"],
+                 [f"{etext} \n 翻成中文,保留原意,但使用文学性的语言。不要写解释。列出3个版本"],
+                 ["js 判断一个数是不是质数"],
+                 ["js 实现python 的 range(10)"],
+                 ["js 实现python 的 [*(range(10)]"],
+                 ["假定 1 + 2 = 4, 试求 7 + 8"],
+                 ["Erkläre die Handlung von Cinderella in einem Satz."],
+                 ["Erkläre die Handlung von Cinderella in einem Satz. Auf Deutsch"],
+             ],
+             inputs=[user_input],
+             examples_per_page=30,
+         )
+
+     with gr.Accordion("For Chat/Translation API", open=False, visible=False):
+         input_text = gr.Text()
+         tr_btn = gr.Button("Go", variant="primary")
+         out_text = gr.Text()
+     tr_btn.click(
+         trans_api,
+         [input_text, max_length, top_p, temperature],
+         out_text,
+         # show_progress="full",
+         api_name="tr",
+     )
+     _ = """
+     input_text.submit(
+         trans_api,
+         [input_text, max_length, top_p, temperature],
+         out_text,
+         show_progress="full",
+         api_name="tr1",
+     )
+     # """
+
+ # demo.queue().launch(share=False, inbrowser=True)
+ # demo.queue().launch(share=True, inbrowser=True, debug=True)
+
+ demo.queue().launch(show_error=True, debug=True)
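Because submitBtn.click registers api_name="predict" and tr_btn.click registers api_name="tr", the queued demo also exposes programmatic endpoints. A hedged sketch with gradio_client (the URL assumes a default local launch(); the argument order mirrors the inputs lists above):

    from gradio_client import Client

    # Assumes the app is running locally on Gradio's default port.
    client = Client("http://127.0.0.1:7860/")

    # /tr maps to trans_api: [input_text, max_length, top_p, temperature].
    result = client.predict("Translate to Chinese: Good morning", 4096, 0.85, 0.2, api_name="/tr")
    print(result)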