ovieyra21 commited on
Commit
b7d0d3d
·
verified ·
1 Parent(s): 35d5a14

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +238 -45
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import torch
2
  import gradio as gr
3
- import yt_dlp as youtube_dl
4
  import numpy as np
5
  from datasets import Dataset, Audio
6
  from scipy.io import wavfile
@@ -13,47 +13,240 @@ import os
13
  import time
14
  import demucs.api
15
 
16
- # Your imports and other functions remain the same
17
-
18
- def function_transcribe(audio_file, task_file, cleaning_file, textbox_file, max_filesize=75.0, dataset_sampling_rate = 24000, progress=gr.Progress()):
19
- if isinstance(audio_file, str):
20
- audio_file = open(audio_file, "rb")
21
-
22
- _, extension = os.path.splitext(audio_file.name)
23
- if extension != '.mp3' and extension != '.wav':
24
- raise RuntimeError("Invalid file format. Supported formats are mp3 and wav.")
25
-
26
- if audio_file.size // (1024 * 1024) > FILE_LIMIT_MB:
27
- raise RuntimeError(f"File size exceeds the limit ({extension} file {FILE_LIMIT_MB} MB).")
28
-
29
- task = task_file.lower()
30
- if task != "transcribe" and task != "translate":
31
- raise RuntimeError("Unsupported task. Task must be either 'transcribe' or 'translate'.")
32
-
33
- cleanup = bool(cleaning_file)
34
-
35
- dataset_name = textbox_file.strip().replace("/", "_").replace(" ", "_")
36
-
37
- audio_content = audio_file.read()
38
- audio_array, sample_rate = wavfile.imread(BytesIO(audio_content), "wav")
39
-
40
- chunks = naive_postprocess_whisper_chunks(audio_array, sample_rate, stop_chars=".<>?", min_duration=5)
41
-
42
- texts = whisper_batch_transcribe(chunks, model=MODEL_NAME, device=device, task=task)
43
-
44
- if cleanup:
45
- cleaned_chunks = clean_audio_chunks(chunks, audio_array, sample_rate)
46
- cleaned_texts = whisper_batch_transcribe(cleaned_chunks, model=MODEL_NAME, device=device, task=task)
47
- texts = cleaned_texts
48
-
49
- texts = [t.strip() for t in texts]
50
-
51
- dataset = Dataset.from_dict({"text": texts})
52
- if dataset_name:
53
- dataset.push_to_hub(dataset_name, repo_type="dataset", private=True)
54
-
55
- return dataset, "\n\n".join(texts)
56
-
57
- # Continuing with the rest of the script
58
-
59
- demo.launch(debug=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
  import gradio as gr
3
+ import youtube_dl
4
  import numpy as np
5
  from datasets import Dataset, Audio
6
  from scipy.io import wavfile
 
13
  import time
14
  import demucs.api
15
 
16
+ MODEL_NAME = "openai/whisper-large-v3" # "patrickvonplaten/wav2vec2-large-960h-lv60-self-4-gram"
17
+ DEMUCS_MODEL_NAME = "htdemucs_ft"
18
+ BATCH_SIZE = 8
19
+ FILE_LIMIT_MB = 1000
20
+ YT_LENGTH_LIMIT_S = 3600 # limit to 1 hour YouTube files
21
+
22
+ device = 0 if torch.cuda.is_available() else "cpu"
23
+
24
+ pipe = pipeline(
25
+ task="automatic-speech-recognition",
26
+ model=MODEL_NAME,
27
+ chunk_length_s=30,
28
+ device=device,
29
+ )
30
+
31
+ separator = demucs.api.Separator(model=DEMUCS_MODEL_NAME, )
32
+
33
+ def separate_vocal(path):
34
+ origin, separated = separator.separate_audio_file(path)
35
+ demucs.api.save_audio(separated["vocals"], path, samplerate=separator.samplerate)
36
+ return path
37
+
38
+ def _return_yt_html_embed(yt_url):
39
+ video_id = yt_url.split("?v=")[-1]
40
+ HTML_str = (
41
+ f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe>'
42
+ " </center>"
43
+ )
44
+ return gr.HTML(value=HTML_str)
45
+
46
+ def download_yt_audio(yt_url, filename):
47
+ info_loader = youtube_dl.YoutubeDL()
48
+
49
+ try:
50
+ info = info_loader.extract_info(yt_url, download=False)
51
+ except youtube_dl.utils.DownloadError as err:
52
+ raise gr.Error(str(err))
53
+
54
+ file_length = info["duration_string"]
55
+ file_h_m_s = file_length.split(":")
56
+ file_h_m_s = [int(sub_length) for sub_length in file_h_m_s]
57
+
58
+ if len(file_h_m_s) == 1:
59
+ file_h_m_s.insert(0, 0)
60
+ if len(file_h_m_s) == 2:
61
+ file_h_m_s.insert(0, 0)
62
+ file_length_s = file_h_m_s[0] * 3600 + file_h_m_s[1] * 60 + file_h_m_s[2]
63
+
64
+ if file_length_s > YT_LENGTH_LIMIT_S:
65
+ yt_length_limit_hms = time.strftime("%HH:%MM:%SS", time.gmtime(YT_LENGTH_LIMIT_S))
66
+ file_length_hms = time.strftime("%HH:%MM:%SS", time.gmtime(file_length_s))
67
+ raise gr.Error(f"Maximum YouTube length is {yt_length_limit_hms}, got {file_length_hms} YouTube video.")
68
+
69
+ ydl_opts = {"outtmpl": filename, "format": "worstvideo[ext=mp4]+bestaudio[ext=m4a]/best[ext=mp4]/best"}
70
+
71
+ with youtube_dl.YoutubeDL(ydl_opts) as ydl:
72
+ try:
73
+ ydl.download([yt_url])
74
+ except youtube_dl.utils.ExtractorError as err:
75
+ raise gr.Error(str(err))
76
+
77
+ def yt_transcribe(yt_url, task, use_demucs, dataset_name, oauth_token: gr.OAuthToken | None, max_filesize=75.0, dataset_sampling_rate = 24000,
78
+ progress=gr.Progress()):
79
+
80
+ if yt_url is None:
81
+ raise gr.Error("No youtube link submitted! Please put a working link.")
82
+ if dataset_name is None:
83
+ raise gr.Error("No dataset name submitted! Please submit a dataset name. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.")
84
+
85
+ total_step = 5
86
+ current_step = 0
87
+
88
+ HTML_str = _return_yt_html_embed(yt_url)
89
+
90
+ if oauth_token is None:
91
+ gr.Warning("Make sure to click and login before using this demo.")
92
+ return HTML_str, [["transcripts will appear here"]], ""
93
+
94
+ current_step += 1
95
+ progress((current_step, total_step), desc="Load video.")
96
+
97
+ with tempfile.TemporaryDirectory() as tmpdirname:
98
+ filepath = os.path.join(tmpdirname, "video.mp4")
99
+
100
+ download_yt_audio(yt_url, filepath)
101
+ with open(filepath, "rb") as f:
102
+ inputs_path = f.read()
103
+
104
+ inputs = ffmpeg_read(inputs_path, pipe.feature_extractor.sampling_rate)
105
+ inputs = {"array": inputs, "sampling_rate": pipe.feature_extractor.sampling_rate}
106
+
107
+ current_step += 1
108
+ progress((current_step, total_step), desc="Transcribe using Whisper.")
109
+ out = pipe(inputs, batch_size=BATCH_SIZE, generate_kwargs={"task": task}, return_timestamps=True)
110
+
111
+ text = out["text"]
112
+
113
+ inputs = ffmpeg_read(inputs_path, dataset_sampling_rate)
114
+
115
+ current_step += 1
116
+ progress((current_step, total_step), desc="Merge chunks.")
117
+ chunks = naive_postprocess_whisper_chunks(out["chunks"], inputs, dataset_sampling_rate)
118
+
119
+ current_step += 1
120
+ progress((current_step, total_step), desc="Create dataset.")
121
+
122
+ transcripts = []
123
+ audios = []
124
+ with tempfile.TemporaryDirectory() as tmpdirname:
125
+ for i,chunk in enumerate(progress.tqdm(chunks, desc="Creating dataset (and clean audio if asked for).")):
126
+
127
+ # TODO: make sure 1D or 2D?
128
+ arr = chunk["audio"]
129
+ path = os.path.join(tmpdirname, f"{i}.wav")
130
+ wavfile.write(path, dataset_sampling_rate, arr)
131
+
132
+ if use_demucs == "separate-audio":
133
+ # use demucs tp separate vocals
134
+ print(f"Separating vocals #{i}")
135
+ path = separate_vocal(path)
136
+
137
+ audios.append(path)
138
+ transcripts.append(chunk["text"])
139
+
140
+ dataset = Dataset.from_dict({"audio": audios, "text": transcripts}).cast_column("audio", Audio())
141
+
142
+ current_step += 1
143
+ progress((current_step, total_step), desc="Push dataset.")
144
+ dataset.push_to_hub(dataset_name, token=oauth_token.token if oauth_token else oauth_token)
145
+
146
+
147
+ return HTML_str, [[transcript] for transcript in transcripts], text
148
+
149
+ def naive_postprocess_whisper_chunks(chunks, audio_array, sampling_rate, stop_chars = ".!:;?", min_duration = 5):
150
+ # merge chunks as long as merged audio duration is lower than min_duration and that a stop character is not met
151
+ # return list of dictionnaries (text, audio)
152
+ # min duration is in seconds
153
+ min_duration = int(min_duration * sampling_rate)
154
+
155
+
156
+ new_chunks = []
157
+ while chunks:
158
+ current_chunk = chunks.pop(0)
159
+
160
+ begin, end = current_chunk["timestamp"]
161
+ begin, end = int(begin*sampling_rate), int(end*sampling_rate)
162
+
163
+ current_dur = end-begin
164
+
165
+ text = current_chunk["text"]
166
+
167
+
168
+ chunk_to_concat = [audio_array[begin:end]]
169
+ while chunks and (text[-1] not in stop_chars or (current_dur<min_duration)):
170
+ ch = chunks.pop(0)
171
+ begin, end = ch["timestamp"]
172
+ begin, end = int(begin*sampling_rate), int(end*sampling_rate)
173
+ current_dur += end-begin
174
+
175
+ text = "".join([text, ch["text"]])
176
+
177
+ # TODO: add silence ?
178
+ chunk_to_concat.append(audio_array[begin:end])
179
+
180
+
181
+ new_chunks.append({
182
+ "text": text.strip(),
183
+ "audio": np.concatenate(chunk_to_concat),
184
+ })
185
+ print(f"LENGTH CHUNK #{len(new_chunks)}: {current_dur/sampling_rate}s")
186
+
187
+ return new_chunks
188
+
189
+ css = """
190
+ #intro{
191
+ max-width: 100%;
192
+ text-align: center;
193
+ margin: 0 auto;
194
+ }
195
+ """
196
+ with gr.Blocks(css=css) as demo:
197
+ with gr.Row():
198
+ gr.LoginButton()
199
+ gr.LogoutButton()
200
+
201
+ with gr.Tab("YouTube"):
202
+ gr.Markdown("Create your own TTS dataset using Youtube", elem_id="intro")
203
+ gr.Markdown(
204
+ "This demo allows use to create a text-to-speech dataset from an input audio snippet and push it to hub to keep track of it."
205
+ f"Demo uses the checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to automatically transcribe audio files"
206
+ " of arbitrary length. It then merge chunks of audio and push it to the hub."
207
+ )
208
+ with gr.Row():
209
+ with gr.Column():
210
+ audio_youtube = gr.Textbox(lines=1, placeholder="Paste the URL to a YouTube video here", label="YouTube URL")
211
+ task_youtube = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
212
+ cleaning_youtube = gr.Radio(["no-post-processing", "separate-audio"], label="Audio separation and cleaning (takes longer - use it if your samples are not cleaned (background noise and music))", value="separate-audio")
213
+ textbox_youtube = gr.Textbox(lines=1, placeholder="Place your new dataset name here. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.", label="Dataset name")
214
+
215
+ with gr.Row():
216
+ clear_youtube = gr.ClearButton([audio_youtube, task_youtube, cleaning_youtube, textbox_youtube])
217
+ submit_youtube = gr.Button("Submit")
218
+
219
+ with gr.Column():
220
+ html_youtube = gr.HTML()
221
+ dataset_youtube = gr.Dataset(label="Transcribed samples.", components=["text"], headers=["Transcripts"], samples=[["transcripts will appear here"]])
222
+ transcript_youtube = gr.Textbox(label="Transcription")
223
+
224
+ with gr.Tab("Microphone or Audio file"):
225
+ gr.Markdown("Create your own TTS dataset using your own recordings", elem_id="intro")
226
+ gr.Markdown(
227
+ "This demo allows use to create a text-to-speech dataset from an input audio snippet and push it to hub to keep track of it."
228
+ f"Demo uses the checkpoint [{MODEL_NAME}](https://huggingface.co/{MODEL_NAME}) and 🤗 Transformers to automatically transcribe audio files"
229
+ " of arbitrary length. It then merge chunks of audio and push it to the hub."
230
+ )
231
+ with gr.Row():
232
+ with gr.Column():
233
+ audio_file = gr.Audio(type="filepath")
234
+ task_file = gr.Radio(["transcribe", "translate"], label="Task", value="transcribe")
235
+ cleaning_file = gr.Radio(["no-post-processing", "separate-audio"], label="Audio separation and cleaning (takes longer - use it if your samples are not cleaned (background noise and music))", value="no-post-processing")
236
+ textbox_file = gr.Textbox(lines=1, placeholder="Place your new dataset name here. Should be in the format : <user>/<dataset_name> or <org>/<dataset_name>. Also accepts <dataset_name>, which will default to the namespace of the logged-in user.", label="Dataset name")
237
+
238
+ with gr.Row():
239
+ clear_file = gr.ClearButton([audio_file, task_file, cleaning_file, textbox_file])
240
+ submit_file = gr.Button("Submit")
241
+
242
+ with gr.Column():
243
+ dataset_file = gr.Dataset(label="Transcribed samples.", components=["text"], headers=["Transcripts"], samples=[["transcripts will appear here"]])
244
+ transcript_file = gr.Textbox(label="Transcription")
245
+
246
+
247
+
248
+ submit_file.click(transcribe, inputs=[audio_file, task_file, cleaning_file, textbox_file], outputs=[dataset_file, transcript_file])
249
+ submit_youtube.click(yt_transcribe, inputs=[audio_youtube, task_youtube, cleaning_youtube, textbox_youtube], outputs=[html_youtube, dataset_youtube, transcript_youtube])
250
+
251
+ demo.launch(debug=True)
252
+ Confío en que