ovieyra21 committed · Commit 6d3b05b · verified · 1 Parent(s): b16790d

Update app.py

Files changed (1):
  1. app.py +119 -176
app.py CHANGED
@@ -1,20 +1,16 @@
-import torch
-
 import gradio as gr
+import torch
+from transformers import pipeline
 import yt_dlp as youtube_dl
+import os
+from scipy.io import wavfile
 import numpy as np
 from datasets import Dataset, Audio
-from scipy.io import wavfile
-
-from transformers import pipeline
-from transformers.pipelines.audio_utils import ffmpeg_read
-
 import tempfile
-import os
 import time
-import demucs.api
+import demucs
 
-MODEL_NAME = "openai/whisper-large-v3" # "patrickvonplaten/wav2vec2-large-960h-lv60-self-4-gram" #
+MODEL_NAME = "openai/whisper-large-v2"
 DEMUCS_MODEL_NAME = "htdemucs_ft"
 BATCH_SIZE = 8
 FILE_LIMIT_MB = 1000
@@ -29,34 +25,29 @@ pipe = pipeline(
     device=device,
 )
 
-separator = demucs.api.Separator(model = DEMUCS_MODEL_NAME, )
+separator = demucs.pretrained.hdemucs()
 
 def separate_vocal(path):
-    origin, separated = separator.separate_audio_file(path)
-    demucs.api.save_audio(separated["vocals"], path, samplerate=separator.samplerate)
-    return path
+    origin, separated = separator(path)
+    vocal_path = os.path.splitext(path)[0] + "_vocals.wav"
+    wavfile.write(vocal_path, separator.samplerate, separated[1].numpy())
+    return vocal_path
 
-
-def transcribe(inputs_path, task, use_demucs, dataset_name, oauth_token: gr.OAuthToken | None, progress=gr.Progress()):
+def transcribe(inputs_path, task, use_demucs, dataset_name, oauth_token, progress=gr.Progress()):
     if inputs_path is None:
         raise gr.Error("No audio file submitted! Please upload or record an audio file before submitting your request.")
-    if dataset_name is None:
+    if not dataset_name:
         raise gr.Error("No dataset name submitted! Please submit a dataset name. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.")
-
     if oauth_token is None:
-        gr.Warning("Make sure to click and login before using this demo.")
-        return [["transcripts will appear here"]], ""
-
+        raise gr.Error("No OAuth token submitted! Please login to use this demo.")
+
     total_step = 4
     current_step = 0
 
     current_step += 1
     progress((current_step, total_step), desc="Transcribe using Whisper.")
-
-    sampling_rate, inputs = wavfile.read(inputs_path)
-
+    sampling_rate, inputs = wavfile.read(inputs_path)
     out = pipe(inputs_path, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
-
     text = out["text"]
 
     current_step += 1
@@ -65,101 +56,49 @@ def transcribe(inputs_path, task, use_demucs, dataset_name, oauth_token: gr.OAut
 
     current_step += 1
     progress((current_step, total_step), desc="Create dataset.")
-
-
     transcripts = []
     audios = []
     with tempfile.TemporaryDirectory() as tmpdirname:
-        for i,chunk in enumerate(progress.tqdm(chunks, desc="Creating dataset (and clean audio if asked for)")):
-
-            # TODO: make sure 1D or 2D?
+        for i, chunk in enumerate(progress.tqdm(chunks, desc="Creating dataset (and clean audio if asked for)")):
             arr = chunk["audio"]
             path = os.path.join(tmpdirname, f"{i}.wav")
-            wavfile.write(path, sampling_rate, arr)
+            wavfile.write(path, sampling_rate, arr)
 
             if use_demucs == "separate-audio":
-                # use demucs tp separate vocals
                 print(f"Separating vocals #{i}")
                 path = separate_vocal(path)
 
             audios.append(path)
             transcripts.append(chunk["text"])
-
+
     dataset = Dataset.from_dict({"audio": audios, "text": transcripts}).cast_column("audio", Audio())
 
     current_step += 1
     progress((current_step, total_step), desc="Push dataset.")
-    dataset.push_to_hub(dataset_name, token=oauth_token.token if oauth_token else oauth_token)
+    dataset.push_to_hub(dataset_name, token=oauth_token)
 
     return [[transcript] for transcript in transcripts], text
 
-
-def _return_yt_html_embed(yt_url):
-    video_id = yt_url.split("?v=")[-1]
-    HTML_str = (
-        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
-        " </center>"
-    )
-    return HTML_str
-
-def download_yt_audio(yt_url, filename):
-    info_loader = youtube_dl.YoutubeDL()
-
-    try:
-        info = info_loader.extract_info(yt_url, download=False)
-    except youtube_dl.utils.DownloadError as err:
-        raise gr.Error(str(err))
-
-    file_length = info["duration_string"]
-    file_h_m_s = file_length.split(":")
-    file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
-
-    if len(file_h_m_s) == 1:
-        file_h_m_s.insert(0, 0)
-    if len(file_h_m_s) == 2:
-        file_h_m_s.insert(0, 0)
-    file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
-
-    if file_length_s > YT_LENGTH_LIMIT_S:
-        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
-        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
-        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
-
-    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
-
-    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
-        try:
-            ydl.download([yt_url])
-        except youtube_dl.utils.ExtractorError as err:
-            raise gr.Error(str(err))
-
-
-def yt_transcribe(yt_url, task, use_demucs, dataset_name, oauth_token: gr.OAuthToken | None, max_filesize=75.0, dataset_sampling_rate = 24000,
-                  progress=gr.Progress()):
-
+def yt_transcribe(yt_url, task, use_demucs, dataset_name, oauth_token, progress=gr.Progress()):
     if yt_url is None:
-        raise gr.Error("No youtube link submitted! Please put a working link.")
-    if dataset_name is None:
+        raise gr.Error("No YouTube URL submitted! Please provide a working link.")
+    if not dataset_name:
         raise gr.Error("No dataset name submitted! Please submit a dataset name. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.")
+    if oauth_token is None:
+        raise gr.Error("No OAuth token submitted! Please login to use this demo.")
 
     total_step = 5
     current_step = 0
 
     html_embed_str = _return_yt_html_embed(yt_url)
 
-    if oauth_token is None:
-        gr.Warning("Make sure to click and login before using this demo.")
-        return html_embed_str, [["transcripts will appear here"]], ""
-
     current_step += 1
     progress((current_step, total_step), desc="Load video.")
 
     with tempfile.TemporaryDirectory() as tmpdirname:
         filepath = os.path.join(tmpdirname, "video.mp4")
-
         download_yt_audio(yt_url, filepath)
-        with open(filepath, "rb") as f:
-            inputs_path = f.read()
+        inputs_path = filepath
 
     inputs = ffmpeg_read(inputs_path, pipe.feature_extractor.sampling_rate)
     inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
@@ -167,7 +106,6 @@ def yt_transcribe(yt_url, task, use_demucs, dataset_name, oauth_token: gr.OAuthT
     current_step += 1
     progress((current_step, total_step), desc="Transcribe using Whisper.")
     out = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
-
     text = out["text"]
 
     inputs = ffmpeg_read(inputs_path, dataset_sampling_rate)
@@ -178,135 +116,140 @@ def yt_transcribe(yt_url, task, use_demucs, dataset_name, oauth_token: gr.OAuthT
 
     current_step += 1
     progress((current_step, total_step), desc="Create dataset.")
-
     transcripts = []
    audios = []
     with tempfile.TemporaryDirectory() as tmpdirname:
-        for i,chunk in enumerate(progress.tqdm(chunks, desc="Creating dataset (and clean audio if asked for).")):
-
-            # TODO: make sure 1D or 2D?
+        for i, chunk in enumerate(progress.tqdm(chunks, desc="Creating dataset (and clean audio if asked for).")):
             arr = chunk["audio"]
             path = os.path.join(tmpdirname, f"{i}.wav")
-            wavfile.write(path, dataset_sampling_rate, arr)
+            wavfile.write(path, dataset_sampling_rate, arr)
 
             if use_demucs == "separate-audio":
-                # use demucs tp separate vocals
                 print(f"Separating vocals #{i}")
                 path = separate_vocal(path)
 
             audios.append(path)
             transcripts.append(chunk["text"])
-
+
     dataset = Dataset.from_dict({"audio": audios, "text": transcripts}).cast_column("audio", Audio())
 
     current_step += 1
     progress((current_step, total_step), desc="Push dataset.")
-    dataset.push_to_hub(dataset_name, token=oauth_token.token if oauth_token else oauth_token)
+    dataset.push_to_hub(dataset_name, token=oauth_token)
 
-
     return html_embed_str, [[transcript] for transcript in transcripts], text
 
-
-def naive_postprocess_whisper_chunks(chunks, audio_array, sampling_rate, stop_chars = ".!:;?", min_duration = 5):
-    # merge chunks as long as merged audio duration is lower than min_duration and that a stop character is not met
-    # return list of dictionnaries (text, audio)
-    # min duration is in seconds
+def naive_postprocess_whisper_chunks(chunks, audio_array, sampling_rate, stop_chars=".!:;?", min_duration=5):
     min_duration = int(min_duration * sampling_rate)
-
-
     new_chunks = []
     while chunks:
         current_chunk = chunks.pop(0)
-
         begin, end = current_chunk["timestamp"]
-        begin, end = int(begin*sampling_rate), int(end*sampling_rate)
-
-        current_dur = end-begin
-
+        begin, end = int(begin * sampling_rate), int(end * sampling_rate)
+        current_dur = end - begin
         text = current_chunk["text"]
-
-
         chunk_to_concat = [audio_array[begin:end]]
-        while chunks and (text[-1] not in stop_chars or (current_dur<min_duration)):
+        while chunks and (text[-1] not in stop_chars or (current_dur < min_duration)):
             ch = chunks.pop(0)
             begin, end = ch["timestamp"]
-            begin, end = int(begin*sampling_rate), int(end*sampling_rate)
-            current_dur += end-begin
-
+            begin, end = int(begin * sampling_rate), int(end * sampling_rate)
+            current_dur += end - begin
             text = "".join([text, ch["text"]])
-
-            # TODO: add silence ?
             chunk_to_concat.append(audio_array[begin:end])
-
-
         new_chunks.append({
-            "text": text.strip(),
-            "audio": np.concatenate(chunk_to_concat),
+            "text": text,
+            "audio": np.concatenate(chunk_to_concat)
        })
-        print(f"LENGTH CHUNK #{len(new_chunks)}: {current_dur/sampling_rate}s")
-
     return new_chunks
 
+def _return_yt_html_embed(yt_url):
+    video_id = yt_url.split("?v=")[-1]
+    HTML_str = (
+        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
+        " </center>"
+    )
+    return HTML_str
+
+def download_yt_audio(yt_url, filename):
+    info_loader = youtube_dl.YoutubeDL()
+    try:
+        info = info_loader.extract_info(yt_url, download=False)
+    except youtube_dl.utils.DownloadError as err:
+        raise gr.Error(str(err))
+
+    file_length = info["duration_string"]
+    file_h_m_s = file_length.split(":")
+    file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
+
+    if len(file_h_m_s) == 1:
+        file_h_m_s.insert(0, 0)
+    if len(file_h_m_s) == 2:
+        file_h_m_s.insert(0, 0)
+    file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
+
+    if file_length_s > YT_LENGTH_LIMIT_S:
+        yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
+        file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
+        raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
+
+    ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
+
+    with youtube_dl.YoutubeDL(ydl_opts) as ydl:
+        try:
+            ydl.download([yt_url])
+        except youtube_dl.utils.ExtractorError as err:
+            raise gr.Error(str(err))
+
 css = """
-#intro{
-    max-width: 100%;
-    text-align: center;
-    margin: 0 auto;
+#intro {
+    padding: 20px;
+    background-color: #f0f0f0;
+    margin-bottom: 10px;
+}
+#intro h1 {
+    font-size: 30px;
 }
 """
-with gr.Blocks(css=css) as demo:
-    with gr.Row():
-        gr.LoginButton()
-        gr.LogoutButton()
-
-    with gr.Tab("YouTube"):
-        gr.Markdown("Create your own TTS dataset using Youtube", elem_id="intro")
-        gr.Markdown(
-            "This demo allows use to create a text-to-speech dataset from an input audio snippet and push it to hub to keep track of it."
-            f"Demo uses the checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to automatically transcribe audio files"
-            " of arbitrary length. It then merge chunks of audio and push it to the hub."
-        )
+gr.config.css(css)
+
+with gr.Blocks() as demo:
+    with gr.Tab("Local file"):
         with gr.Row():
             with gr.Column():
-                audio_youtube = gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
-                task_youtube = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
-                cleaning_youtube = gr.Radio(["no-post-processing", "separate-audio"], label="Audio separation and cleaning (takes longer - use it if your samples are not cleaned (background noise and music))", value="separate-audio")
-                textbox_youtube = gr.Textbox(lines=1, placeholder="Place your new dataset name here. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.", label="Dataset name")
-
-            with gr.Row():
-                clear_youtube = gr.ClearButton([audio_youtube, task_youtube, cleaning_youtube, textbox_youtube])
-                submit_youtube = gr.Button("Submit")
-
+                local_audio_input = gr.Audio(type="filepath", label="Upload Audio")
+                task_input = gr.Dropdown(choices=["transcribe", "translate"], value="transcribe", label="Task")
+                use_demucs_input = gr.Dropdown(choices=["do-nothing", "separate-audio"], value="do-nothing", label="Audio preprocessing")
+                dataset_name_input = gr.Textbox(label="Dataset name")
+                hf_token = gr.Textbox(label="HuggingFace Token")
+                submit_local_button = gr.Button("Transcribe")
             with gr.Column():
-                html_youtube = gr.HTML()
-                dataset_youtube = gr.Dataset(label="Transcribed samples.",components=["text"], headers=["Transcripts"], samples=[["transcripts will appear here"]])
-                transcript_youtube = gr.Textbox(label="Transcription")
-
-    with gr.Tab("Microphone or Audio file"):
-        gr.Markdown("Create your own TTS dataset using your own recordings", elem_id="intro")
-        gr.Markdown(
-            "This demo allows use to create a text-to-speech dataset from an input audio snippet and push it to hub to keep track of it."
-            f"Demo uses the checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to automatically transcribe audio files"
-            " of arbitrary length. It then merge chunks of audio and push it to the hub."
-        )
+                local_output_text = gr.Dataframe(label="Transcripts")
+                local_output_full_text = gr.Textbox(label="Full Text")
+
+        submit_local_button.click(
+            transcribe,
+            inputs=[local_audio_input, task_input, use_demucs_input, dataset_name_input, hf_token],
+            outputs=[local_output_text, local_output_full_text],
+        )
+
+    with gr.Tab("YouTube video"):
         with gr.Row():
             with gr.Column():
-                audio_file = gr.Audio(type="filepath")
-                task_file = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
-                cleaning_file = gr.Radio(["no-post-processing", "separate-audio"], label="Audio separation and cleaning (takes longer - use it if your samples are not cleaned (background noise and music))", value="separate-audio")
-                textbox_file = gr.Textbox(lines=1, placeholder="Place your new dataset name here. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.", label="Dataset name")
-
-            with gr.Row():
-                clear_file = gr.ClearButton([audio_file, task_file, cleaning_file, textbox_file])
-                submit_file = gr.Button("Submit")
-
+                yt_url_input = gr.Textbox(label="YouTube URL")
+                yt_task_input = gr.Dropdown(choices=["transcribe", "translate"], value="transcribe", label="Task")
+                yt_use_demucs_input = gr.Dropdown(choices=["do-nothing", "separate-audio"], value="do-nothing", label="Audio preprocessing")
+                yt_dataset_name_input = gr.Textbox(label="Dataset name")
+                yt_hf_token = gr.Textbox(label="HuggingFace Token")
+                submit_yt_button = gr.Button("Transcribe")
             with gr.Column():
-                dataset_file = gr.Dataset(label="Transcribed samples.", components=["text"], headers=["Transcripts"], samples=[["transcripts will appear here"]])
-                transcript_file = gr.Textbox(label="Transcription")
-
-
+                yt_html_embed_str = gr.HTML()
+                yt_output_text = gr.Dataframe(label="Transcripts")
+                yt_output_full_text = gr.Textbox(label="Full Text")
 
-    submit_file.click(transcribe, inputs=[audio_file, task_file, cleaning_file, textbox_file], outputs=[dataset_file, transcript_file])
-    submit_youtube.click(yt_transcribe, inputs=[audio_youtube, task_youtube, cleaning_youtube, textbox_youtube], outputs=[html_youtube, dataset_youtube, transcript_youtube])
-
-demo.launch(debug=True)
+        submit_yt_button.click(
+            yt_transcribe,
+            inputs=[yt_url_input, yt_task_input, yt_use_demucs_input, yt_dataset_name_input, yt_hf_token],
+            outputs=[yt_html_embed_str, yt_output_text, yt_output_full_text],
+        )
+
+demo.launch(share=True)
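A note on the YouTube path: the new version drops the removed import "from transformers.pipelines.audio_utils import ffmpeg_read" but keeps calling ffmpeg_read in yt_transcribe, and it now passes a file path (inputs_path = filepath) where the removed code read the file into bytes first. A minimal sketch of the call as transformers defines it; the file name and the 16 kHz rate (Whisper's feature-extractor default) are illustrative, not part of the commit:

    from transformers.pipelines.audio_utils import ffmpeg_read

    # ffmpeg_read expects the raw bytes of a media file and pipes them through
    # ffmpeg, returning a 1-D float32 numpy array at the requested sampling rate.
    with open("video.mp4", "rb") as f:  # read bytes, as the removed code did
        audio_bytes = f.read()

    inputs = ffmpeg_read(audio_bytes, 16_000)
    print(inputs.shape, inputs.dtype)  # (n_samples,) float32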
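The rewritten separate_vocal builds its separator with demucs.pretrained.hdemucs() and calls it directly on a path; demucs' documented high-level interface in recent releases (demucs >= 4.1) is the demucs.api.Separator that the removed lines used, which takes a file path and returns stem waveforms by name. A sketch of the separation step with that API, overwriting the input file in place as the old code did:

    import demucs.api

    # Load the fine-tuned Hybrid Transformer Demucs model by name.
    separator = demucs.api.Separator(model="htdemucs_ft")

    def separate_vocal(path):
        # separate_audio_file returns the original waveform plus a dict of
        # stems keyed by name ("vocals", "drums", "bass", "other").
        origin, separated = separator.separate_audio_file(path)
        # Keep only the vocals stem, writing it back over the input file.
        demucs.api.save_audio(separated["vocals"], path, samplerate=separator.samplerate)
        return path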
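Both entry points rely on a chunks variable built from the pipeline output in unchanged code between the hunks shown. With return_timestamps=True, the transformers ASR pipeline returns the full text plus per-segment timestamps, and naive_postprocess_whisper_chunks merges segments into clips of at least min_duration seconds that end on a stop character. A sketch with a fabricated two-segment output and stand-in audio; it assumes naive_postprocess_whisper_chunks from app.py is in scope:

    import numpy as np

    sampling_rate = 16_000
    # Shape of a pipeline result with return_timestamps=True (values invented).
    out = {
        "text": " Hello there. General Kenobi.",
        "chunks": [
            {"timestamp": (0.0, 1.4), "text": " Hello there."},
            {"timestamp": (1.4, 3.0), "text": " General Kenobi."},
        ],
    }
    audio_array = np.zeros(int(3.0 * sampling_rate), dtype=np.float32)  # stand-in audio

    merged = naive_postprocess_whisper_chunks(out["chunks"], audio_array, sampling_rate)
    for chunk in merged:
        print(repr(chunk["text"]), chunk["audio"].shape)
    # Both segments land in one merged chunk: 3.0 s is still below the 5 s minimum.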
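The dataset-creation step is the same in both tabs: the per-chunk wav files and transcripts become a Hugging Face Datasets audio dataset that is pushed to the Hub. A standalone sketch of just that step; the file names, repository name, and token are placeholders:

    from datasets import Audio, Dataset

    # Paths written with wavfile.write(...) inside the temporary directory,
    # paired with the merged chunk transcripts.
    audios = ["0.wav", "1.wav"]
    transcripts = ["first clip.", "second clip."]

    dataset = Dataset.from_dict({"audio": audios, "text": transcripts})
    # cast_column turns the path strings into a decodable Audio feature.
    dataset = dataset.cast_column("audio", Audio())

    # push_to_hub takes a User Access Token; the new UI collects it in a
    # Textbox, the old one took it from gr.OAuthToken. Placeholders below.
    dataset.push_to_hub("user/my-tts-dataset", token="hf_xxx")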
  css = """
204
+ #intro {
205
+ padding: 20px;
206
+ background-color: #f0f0f0;
207
+ margin-bottom: 10px;
208
+ }
209
+ #intro h1 {
210
+ font-size: 30px;
211
  }
212
  """
213
+ gr.config.css(css)
214
+
215
+ with gr.Blocks() as demo:
216
+ with gr.Tab("Local file"):
 
 
 
 
 
 
 
 
217
  with gr.Row():
218
  with gr.Column():
219
+ local_audio_input = gr.Audio(type="filepath", label="Upload Audio")
220
+ task_input = gr.Dropdown(choices=["transcribe", "translate"], value="transcribe", label="Task")
221
+ use_demucs_input = gr.Dropdown(choices=["do-nothing", "separate-audio"], value="do-nothing", label="Audio preprocessing")
222
+ dataset_name_input = gr.Textbox(label="Dataset name")
223
+ hf_token = gr.Textbox(label="HuggingFace Token")
224
+ submit_local_button = gr.Button("Transcribe")
 
 
 
225
  with gr.Column():
226
+ local_output_text = gr.Dataframe(label="Transcripts")
227
+ local_output_full_text = gr.Textbox(label="Full Text")
228
+
229
+ submit_local_button.click(
230
+ transcribe,
231
+ inputs=[local_audio_input, task_input, use_demucs_input, dataset_name_input, hf_token],
232
+ outputs=[local_output_text, local_output_full_text],
233
+ )
234
+
235
+ with gr.Tab("YouTube video"):
 
236
  with gr.Row():
237
  with gr.Column():
238
+ yt_url_input = gr.Textbox(label="YouTube URL")
239
+ yt_task_input = gr.Dropdown(choices=["transcribe", "translate"], value="transcribe", label="Task")
240
+ yt_use_demucs_input = gr.Dropdown(choices=["do-nothing", "separate-audio"], value="do-nothing", label="Audio preprocessing")
241
+ yt_dataset_name_input = gr.Textbox(label="Dataset name")
242
+ yt_hf_token = gr.Textbox(label="HuggingFace Token")
243
+ submit_yt_button = gr.Button("Transcribe")
 
 
 
244
  with gr.Column():
245
+ yt_html_embed_str = gr.HTML()
246
+ yt_output_text = gr.Dataframe(label="Transcripts")
247
+ yt_output_full_text = gr.Textbox(label="Full Text")
 
248
 
249
+ submit_yt_button.click(
250
+ yt_transcribe,
251
+ inputs=[yt_url_input, yt_task_input, yt_use_demucs_input, yt_dataset_name_input, yt_hf_token],
252
+ outputs=[yt_html_embed_str, yt_output_text, yt_output_full_text],
253
+ )
254
+
255
+ demo.launch(share=True)
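One caveat in the new UI code: gr.config.css(css) is not part of Gradio's public API, so the stylesheet would likely go unused. Gradio's documented route is the one the removed version took, passing the string to gr.Blocks. A minimal sketch using the commit's stylesheet:

    import gradio as gr

    css = """
    #intro {
        padding: 20px;
        background-color: #f0f0f0;
        margin-bottom: 10px;
    }
    """

    # Blocks accepts custom CSS directly; elem_id="intro" ties a component
    # to the #intro selector above.
    with gr.Blocks(css=css) as demo:
        gr.Markdown("Create your own TTS dataset", elem_id="intro")

    demo.launch()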