Nick088 committed
Commit
b405c3d
1 Parent(s): 0c21826

add speech to text tab


add speech to text tab (whisper-large-v3 for transcription and translation), plus the helper functions it relies on (check file, split, merge)
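
For reference, the new tab boils down to these two Groq SDK calls. This is a minimal sketch of what the transcribe_audio / translate_audio helpers in the diff below do; "audio.mp3" is a placeholder path and Groq_Api_Key matches the env var the app already uses:

import os
from groq import Groq

client = Groq(api_key=os.environ.get("Groq_Api_Key"))

# transcription in the source language (language=None lets whisper auto-detect)
with open("audio.mp3", "rb") as f:
    transcription = client.audio.transcriptions.create(
        file=("audio.mp3", f.read()),
        model="whisper-large-v3",
        response_format="text",
        language="en",
        temperature=0.0,
    )

# translation of the speech into English text
with open("audio.mp3", "rb") as f:
    translation = client.audio.translations.create(
        file=("audio.mp3", f.read()),
        model="whisper-large-v3",
        response_format="text",
        temperature=0.0,
    )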

Files changed (1)
  1. app.py +465 -39
app.py CHANGED
@@ -1,12 +1,50 @@
 import os
+import subprocess
 import random
+import numpy as np
+import json
+from datetime import timedelta
+import tempfile
+import re
 import gradio as gr
+import groq
 from groq import Groq
 
-client = Groq(
-    api_key = os.environ.get("Groq_Api_Key")
-)
+
+# setup groq
+
+client = Groq(api_key=os.environ.get("Groq_Api_Key"))
+
+def handle_groq_error(e, model_name):
+    error_data = e.args[0]
+
+    if isinstance(error_data, str):
+        # Use regex to extract the JSON part of the string
+        json_match = re.search(r'(\{.*\})', error_data)
+        if json_match:
+            json_str = json_match.group(1)
+            # Ensure the JSON string is well-formed
+            json_str = json_str.replace("'", '"')  # Replace single quotes with double quotes
+            error_data = json.loads(json_str)
+
+    if isinstance(e, groq.RateLimitError):
+        if isinstance(error_data, dict) and 'error' in error_data and 'message' in error_data['error']:
+            error_message = error_data['error']['message']
+            raise gr.Error(error_message)
+    else:
+        raise gr.Error(f"Error during Groq API call: {e}")
+
+
+# llms
+
+MAX_SEED = np.iinfo(np.int32).max
+
+def update_max_tokens(model):
+    if model in ["llama3-70b-8192", "llama3-8b-8192", "gemma-7b-it", "gemma2-9b-it"]:
+        return gr.update(maximum=8192)
+    elif model == "mixtral-8x7b-32768":
+        return gr.update(maximum=32768)
+
 def create_history_messages(history):
     history_messages = [{"role": "user", "content": m[0]} for m in history]
     history_messages.extend([{"role": "assistant", "content": m[1]} for m in history])
 
@@ -18,40 +56,428 @@ def generate_response(prompt, history, model, temperature, max_tokens, top_p, se
     print(messages)
 
     if seed == 0:
-        seed = random.randint(1, 100000)
-
-    stream = client.chat.completions.create(
-        messages=messages,
-        model=model,
-        temperature=temperature,
-        max_tokens=max_tokens,
-        top_p=top_p,
-        seed=seed,
-        stop=None,
-        stream=True,
-    )
+        seed = random.randint(1, MAX_SEED)
+
+    try:
+        stream = client.chat.completions.create(
+            messages=messages,
+            model=model,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+            seed=seed,
+            stop=None,
+            stream=True,
+        )
+
-    response = ""
-    for chunk in stream:
-        delta_content = chunk.choices[0].delta.content
-        if delta_content is not None:
-            response += delta_content
-            yield response
-
-    return response
-
-additional_inputs = [
-    gr.Dropdown(choices=["llama3-70b-8192", "llama3-8b-8192", "mixtral-8x7b-32768", "gemma-7b-it", "gemma2-9b-it"], value="llama3-70b-8192", label="Model"),
-    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Temperature", info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative."),
-    gr.Slider(minimum=1, maximum=32192, step=1, value=4096, label="Max Tokens", info="The maximum number of tokens that the model can process in a single response.<br>Maximums: 8k for gemma 7b it, gemma2 9b it, llama 7b & 70b, 32k for mixtral 8x7b."),
-    gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.5, label="Top P", info="A method of text generation where a model will only consider the most probable next tokens that make up the probability p."),
-    gr.Number(precision=0, value=42, label="Seed", info="A starting point to initiate generation, use 0 for random")
-]
-
-gr.ChatInterface(
-    fn=generate_response,
-    chatbot=gr.Chatbot(show_label=False, show_share_button=False, show_copy_button=True, likeable=True, layout="panel"),
-    additional_inputs=additional_inputs,
-    title="Groq API UI",
-    description="Inference by Groq. Hugging Face Space by [Nick088](https://linktr.ee/Nick088)",
-).launch()
+        response = ""
+        for chunk in stream:
+            delta_content = chunk.choices[0].delta.content
+            if delta_content is not None:
+                response += delta_content
+                yield response
+
+        return response
+    except groq.APIError as e:
+        handle_groq_error(e, model)
+
+# speech to text
+
+ALLOWED_FILE_EXTENSIONS = ["mp3", "mp4", "mpeg", "mpga", "m4a", "wav", "webm"]
+MAX_FILE_SIZE_MB = 25
+CHUNK_SIZE_MB = 25
+
+LANGUAGE_CODES = {
+    "English": "en",
+    "Chinese": "zh",
+    "German": "de",
+    "Spanish": "es",
+    "Russian": "ru",
+    "Korean": "ko",
+    "French": "fr",
+    "Japanese": "ja",
+    "Portuguese": "pt",
+    "Turkish": "tr",
+    "Polish": "pl",
+    "Catalan": "ca",
+    "Dutch": "nl",
+    "Arabic": "ar",
+    "Swedish": "sv",
+    "Italian": "it",
+    "Indonesian": "id",
+    "Hindi": "hi",
+    "Finnish": "fi",
+    "Vietnamese": "vi",
+    "Hebrew": "he",
+    "Ukrainian": "uk",
+    "Greek": "el",
+    "Malay": "ms",
+    "Czech": "cs",
+    "Romanian": "ro",
+    "Danish": "da",
+    "Hungarian": "hu",
+    "Tamil": "ta",
+    "Norwegian": "no",
+    "Thai": "th",
+    "Urdu": "ur",
+    "Croatian": "hr",
+    "Bulgarian": "bg",
+    "Lithuanian": "lt",
+    "Latin": "la",
+    "Māori": "mi",
+    "Malayalam": "ml",
+    "Welsh": "cy",
+    "Slovak": "sk",
+    "Telugu": "te",
+    "Persian": "fa",
+    "Latvian": "lv",
+    "Bengali": "bn",
+    "Serbian": "sr",
+    "Azerbaijani": "az",
+    "Slovenian": "sl",
+    "Kannada": "kn",
+    "Estonian": "et",
+    "Macedonian": "mk",
+    "Breton": "br",
+    "Basque": "eu",
+    "Icelandic": "is",
+    "Armenian": "hy",
+    "Nepali": "ne",
+    "Mongolian": "mn",
+    "Bosnian": "bs",
+    "Kazakh": "kk",
+    "Albanian": "sq",
+    "Swahili": "sw",
+    "Galician": "gl",
+    "Marathi": "mr",
+    "Panjabi": "pa",
+    "Sinhala": "si",
+    "Khmer": "km",
+    "Shona": "sn",
+    "Yoruba": "yo",
+    "Somali": "so",
+    "Afrikaans": "af",
+    "Occitan": "oc",
+    "Georgian": "ka",
+    "Belarusian": "be",
+    "Tajik": "tg",
+    "Sindhi": "sd",
+    "Gujarati": "gu",
+    "Amharic": "am",
+    "Yiddish": "yi",
+    "Lao": "lo",
+    "Uzbek": "uz",
+    "Faroese": "fo",
+    "Haitian": "ht",
+    "Pashto": "ps",
+    "Turkmen": "tk",
+    "Norwegian Nynorsk": "nn",
+    "Maltese": "mt",
+    "Sanskrit": "sa",
+    "Luxembourgish": "lb",
+    "Burmese": "my",
+    "Tibetan": "bo",
+    "Tagalog": "tl",
+    "Malagasy": "mg",
+    "Assamese": "as",
+    "Tatar": "tt",
+    "Hawaiian": "haw",
+    "Lingala": "ln",
+    "Hausa": "ha",
+    "Bashkir": "ba",
+    "Javanese": "jw",
+    "Sundanese": "su",
+}
+
+
+def split_audio(audio_file_path, chunk_size_mb):
+    chunk_size = chunk_size_mb * 1024 * 1024  # Convert MB to bytes
+    file_number = 1
+    chunks = []
+    with open(audio_file_path, 'rb') as f:
+        chunk = f.read(chunk_size)
+        while chunk:
+            chunk_name = f"{os.path.splitext(audio_file_path)[0]}_part{file_number:03}.mp3"  # Pad file number for correct ordering
+            with open(chunk_name, 'wb') as chunk_file:
+                chunk_file.write(chunk)
+            chunks.append(chunk_name)
+            file_number += 1
+            chunk = f.read(chunk_size)
+    return chunks
+
+def merge_audio(chunks, output_file_path):
+    with open("temp_list.txt", "w") as f:
+        for file in chunks:
+            f.write(f"file '{file}'\n")
+    try:
+        subprocess.run(
+            [
+                "ffmpeg",
+                "-f", "concat",
+                "-safe", "0",
+                "-i", "temp_list.txt",
+                "-c", "copy",
+                "-y", output_file_path,
+            ],
+            check=True,
+        )
+        os.remove("temp_list.txt")
+        for chunk in chunks:
+            os.remove(chunk)
+    except subprocess.CalledProcessError as e:
+        raise gr.Error(f"Error during audio merging: {e}")
+
234
+ # Checks file extension, size, and downsamples or splits if needed.
235
+ def check_file(audio_file_path):
236
+ if not audio_file_path:
237
+ raise gr.Error("Please upload an audio file.")
238
+
239
+ file_size_mb = os.path.getsize(audio_file_path) / (1024 * 1024)
240
+ file_extension = audio_file_path.split(".")[-1].lower()
241
+
242
+ if file_extension not in ALLOWED_FILE_EXTENSIONS:
243
+ raise gr.Error(f"Invalid file type (.{file_extension}). Allowed types: {', '.join(ALLOWED_FILE_EXTENSIONS)}")
244
+
245
+ if file_size_mb > MAX_FILE_SIZE_MB:
246
+ gr.Warning(
247
+ f"File size too large ({file_size_mb:.2f} MB). Attempting to downsample to 16kHz MP3 128kbps. Maximum size allowed: {MAX_FILE_SIZE_MB} MB"
248
+ )
249
+
250
+ output_file_path = os.path.splitext(audio_file_path)[0] + "_downsampled.mp3"
251
+ try:
252
+ subprocess.run(
253
+ [
254
+ "ffmpeg",
255
+ "-i",
256
+ audio_file_path,
257
+ "-ar",
258
+ "16000",
259
+ "-ab",
260
+ "128k",
261
+ "-ac",
262
+ "1",
263
+ "-y",
264
+ output_file_path,
265
+ ],
266
+ check=True
267
+ )
268
+
269
+ # Check size after downsampling
270
+ downsampled_size_mb = os.path.getsize(output_file_path) / (1024 * 1024)
271
+ if downsampled_size_mb > MAX_FILE_SIZE_MB:
272
+ gr.Warning(f"File still too large after downsampling ({downsampled_size_mb:.2f} MB). Splitting into {CHUNK_SIZE_MB} MB chunks.")
273
+ return split_audio(output_file_path, CHUNK_SIZE_MB), "split"
274
+
275
+ return output_file_path, None
276
+ except subprocess.CalledProcessError as e:
277
+ raise gr.Error(f"Error during downsampling: {e}")
278
+ return audio_file_path, None
279
+
280
+
281
+ def transcribe_audio(audio_file_path, model, prompt, language, auto_detect_language):
282
+ processed_path, split_status = check_file(audio_file_path)
283
+ full_transcription = ""
284
+
285
+ if split_status == "split":
286
+ processed_chunks = []
287
+ for i, chunk_path in enumerate(processed_path):
288
+ try:
289
+ with open(chunk_path, "rb") as file:
290
+ transcription = client.audio.transcriptions.create(
291
+ file=(os.path.basename(chunk_path), file.read()),
292
+ model=model,
293
+ prompt=prompt,
294
+ response_format="text",
295
+ language=None if auto_detect_language else language,
296
+ temperature=0.0,
297
+ )
298
+ full_transcription += transcription
299
+ processed_chunks.append(chunk_path)
300
+ except groq.RateLimitError as e: # Handle rate limit error
301
+ handle_groq_error(e, model)
302
+ gr.Warning(f"API limit reached during chunk {i+1}. Returning processed chunks only.")
303
+ if processed_chunks:
304
+ merge_audio(processed_chunks, 'merged_output.mp3')
305
+ return full_transcription, 'merged_output.mp3'
306
+ else:
307
+ return "Transcription failed due to API limits.", None
308
+ merge_audio(processed_path, 'merged_output.mp3')
309
+ return full_transcription, 'merged_output.mp3'
310
+ else:
311
+ try:
312
+ with open(processed_path, "rb") as file:
313
+ transcription = client.audio.transcriptions.create(
314
+ file=(os.path.basename(processed_path), file.read()),
315
+ model=model,
316
+ prompt=prompt,
317
+ response_format="text",
318
+ language=None if auto_detect_language else language,
319
+ temperature=0.0,
320
+ )
321
+ return transcription.text, None
322
+ except groq.RateLimitError as e: # Handle rate limit error
323
+ handle_groq_error(e, model)
+
+def translate_audio(audio_file_path, model, prompt):
+    processed_path, split_status = check_file(audio_file_path)
+    full_translation = ""
+
+    if split_status == "split":
+        for chunk_path in processed_path:
+            try:
+                with open(chunk_path, "rb") as file:
+                    translation = client.audio.translations.create(
+                        file=(os.path.basename(chunk_path), file.read()),
+                        model=model,
+                        prompt=prompt,
+                        response_format="text",
+                        temperature=0.0,
+                    )
+                full_translation += translation
+            except groq.APIError as e:
+                if full_translation:
+                    return f"API limit reached. Partial translation: {full_translation}"
+                handle_groq_error(e, model)
+        return full_translation
+    else:
+        try:
+            with open(processed_path, "rb") as file:
+                translation = client.audio.translations.create(
+                    file=(os.path.basename(processed_path), file.read()),
+                    model=model,
+                    prompt=prompt,
+                    response_format="text",
+                    temperature=0.0,
+                )
+            return translation
+        except groq.APIError as e:
+            handle_groq_error(e, model)
+
+
+with gr.Blocks() as interface:
+    gr.Markdown(
+        """
+        # Groq API UI
+        Inference by Groq API
+        If you are having API Rate Limit issues, you can retry later based on the [rate limits](https://console.groq.com/docs/rate-limits) or <a href="https://huggingface.co/spaces/Nick088/Fast-Subtitle-Maker?duplicate=true" style="display: inline-block;margin-top: .5em;margin-right: .25em;" target="_blank"> <img style="margin-bottom: 0em;display: inline;margin-top: -.25em;" src="https://img.shields.io/badge/-Duplicate%20Space-blue?labelColor=white&style=flat&logo=data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAABAAAAAQCAYAAAAf8/9hAAAAAXNSR0IArs4c6QAAAP5JREFUOE+lk7FqAkEURY+ltunEgFXS2sZGIbXfEPdLlnxJyDdYB62sbbUKpLbVNhyYFzbrrA74YJlh9r079973psed0cvUD4A+4HoCjsA85X0Dfn/RBLBgBDxnQPfAEJgBY+A9gALA4tcbamSzS4xq4FOQAJgCDwV2CPKV8tZAJcAjMMkUe1vX+U+SMhfAJEHasQIWmXNN3abzDwHUrgcRGmYcgKe0bxrblHEB4E/pndMazNpSZGcsZdBlYJcEL9Afo75molJyM2FxmPgmgPqlWNLGfwZGG6UiyEvLzHYDmoPkDDiNm9JR9uboiONcBXrpY1qmgs21x1QwyZcpvxt9NS09PlsPAAAAAElFTkSuQmCC&logoWidth=14" alt="Duplicate Space"></a> with <a href=https://console.groq.com/keys>your own API Key</a>
+        Hugging Face Space by [Nick088](https://linktr.ee/Nick088)
+        """
+    )
+    with gr.Tabs():
+        with gr.TabItem("LLMs"):
+            with gr.Row():
+                with gr.Column(scale=1, min_width=250):
+                    model = gr.Dropdown(
+                        choices=[
+                            "llama3-70b-8192",
+                            "llama3-8b-8192",
+                            "mixtral-8x7b-32768",
+                            "gemma-7b-it",
+                            "gemma2-9b-it",
+                        ],
+                        value="llama3-70b-8192",
+                        label="Model",
+                    )
+                    temperature = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        step=0.01,
+                        value=0.5,
+                        label="Temperature",
+                        info="Controls diversity of the generated text. Lower is more deterministic, higher is more creative.",
+                    )
+                    max_tokens = gr.Slider(
+                        minimum=1,
+                        maximum=8192,
+                        step=1,
+                        value=4096,
+                        label="Max Tokens",
+                        info="The maximum number of tokens that the model can process in a single response.<br>Maximums: 8k for gemma 7b it, gemma2 9b it, llama 8b & 70b, 32k for mixtral 8x7b.",
+                    )
+                    top_p = gr.Slider(
+                        minimum=0.0,
+                        maximum=1.0,
+                        step=0.01,
+                        value=0.5,
+                        label="Top P",
+                        info="A method of text generation where a model will only consider the most probable next tokens that make up the probability p.",
+                    )
+                    seed = gr.Number(
+                        precision=0, value=42, label="Seed", info="A starting point to initiate generation, use 0 for random"
+                    )
+                with gr.Column(scale=1, min_width=400):
+                    chatbot = gr.ChatInterface(
+                        fn=generate_response,
+                        chatbot=None,
+                        additional_inputs=[
+                            model,
+                            temperature,
+                            max_tokens,
+                            top_p,
+                            seed,
+                        ],
+                    )
+            model.change(update_max_tokens, inputs=[model], outputs=max_tokens)
+        with gr.TabItem("Speech To Text"):
+            with gr.Tabs():
+                with gr.TabItem("Transcription"):
+                    gr.Markdown("Transcribe audio from files to text!")
+                    with gr.Row():
+                        audio_input = gr.File(
+                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
+                        )
+                        model_choice_transcribe = gr.Dropdown(
+                            choices=["whisper-large-v3"],  # Only include 'whisper-large-v3'
+                            value="whisper-large-v3",
+                            label="Model",
+                        )
+                    with gr.Row():
+                        transcribe_prompt = gr.Textbox(
+                            label="Prompt (Optional)",
+                            info="Specify any context or spelling corrections.",
+                        )
+                        with gr.Column():
+                            language = gr.Dropdown(
+                                choices=[(lang, code) for lang, code in LANGUAGE_CODES.items()],
+                                value="en",
+                                label="Language",
+                            )
+                            auto_detect_language = gr.Checkbox(label="Auto Detect Language")
+                    transcribe_button = gr.Button("Transcribe")
+                    transcription_output = gr.Textbox(label="Transcription")
+                    merged_audio_output = gr.File(label="Merged Audio (if chunked)")
+                    transcribe_button.click(
+                        transcribe_audio,
+                        inputs=[audio_input, model_choice_transcribe, transcribe_prompt, language, auto_detect_language],
+                        outputs=[transcription_output, merged_audio_output],
+                    )
+                with gr.TabItem("Translation"):
+                    gr.Markdown("Transcribe audio from files and translate it to English text!")
+                    with gr.Row():
+                        audio_input_translate = gr.File(
+                            type="filepath", label="Upload File containing Audio", file_types=[f".{ext}" for ext in ALLOWED_FILE_EXTENSIONS]
+                        )
+                        model_choice_translate = gr.Dropdown(
+                            choices=["whisper-large-v3"],  # Only include 'whisper-large-v3'
+                            value="whisper-large-v3",
+                            label="Model",
+                        )
+                    with gr.Row():
+                        translate_prompt = gr.Textbox(
+                            label="Prompt (Optional)",
+                            info="Specify any context or spelling corrections.",
+                        )
+                    translate_button = gr.Button("Translate")
+                    translation_output = gr.Textbox(label="Translation")
+                    translate_button.click(
+                        translate_audio,
+                        inputs=[audio_input_translate, model_choice_translate, translate_prompt],
+                        outputs=translation_output,
+                    )
+
+
+interface.launch(share=True)