nithinraok committed
Commit 0b03171 • 1 Parent(s): 05ddcdd

Create app.py


initial version

Files changed (1)
  1. app.py +289 -0
app.py ADDED
@@ -0,0 +1,289 @@
+ import gradio as gr
+ import json
+ import librosa
+ import os
+ import soundfile as sf
+ import tempfile
+ import uuid
+
+ from nemo.collections.asr.models import ASRModel
+
+ SAMPLE_RATE = 16000  # Hz
+
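+ # load the pretrained Canary model once at startup; eval() puts the underlying
+ # PyTorch module into inference mode (e.g. disables dropout)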
+ model = ASRModel.from_pretrained("nvidia/canary-1b")
+ model.eval()
+
+
+ MAX_AUDIO_SECONDS = 40
+
+
+ def convert_audio(audio_filepath, tmpdir, utt_id):
+     """
+     Convert the input file to a mono 16 kHz wav file.
+     Raises an error (without converting) if the audio is too long.
+     Returns the output filename and the duration.
+     """
+
+     # load at the native sample rate (librosa downmixes to mono by default)
+     data, sr = librosa.load(audio_filepath, sr=None)
+
+     duration = librosa.get_duration(y=data, sr=sr)
+
+     if duration > MAX_AUDIO_SECONDS:
+         raise gr.Error(
+             f"This demo can transcribe up to {MAX_AUDIO_SECONDS} seconds of audio."
+         )
+
+     if sr != SAMPLE_RATE:
+         data = librosa.resample(data, orig_sr=sr, target_sr=SAMPLE_RATE)
+
+     # monochannel
+     data = librosa.to_mono(data)
+
+     out_filename = os.path.join(tmpdir, utt_id + '.wav')
+
+     # save output audio
+     sf.write(out_filename, data, SAMPLE_RATE)
+
+     return out_filename, duration
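+
+ # Illustrative usage (hypothetical file): a 30-second 44.1 kHz clip would come back
+ # as a mono 16 kHz wav, e.g.
+ #   convert_audio("clip.flac", "/tmp/demo", "utt0") -> ("/tmp/demo/utt0.wav", 30.0)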
+
+
+ def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
+
+     if audio_filepath is None:
+         raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")
+
+     utt_id = uuid.uuid4()
+     with tempfile.TemporaryDirectory() as tmpdir:
+
+         converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))
+
+         # map src_lang and tgt_lang from long versions to short
+         LANG_LONG_TO_LANG_SHORT = {
+             "English": "en",
+             "Spanish": "es",
+             "French": "fr",
+             "German": "de",
+         }
+         if src_lang not in LANG_LONG_TO_LANG_SHORT:
+             raise ValueError(f"src_lang must be one of {list(LANG_LONG_TO_LANG_SHORT)}")
+         else:
+             src_lang = LANG_LONG_TO_LANG_SHORT[src_lang]
+
+         if tgt_lang not in LANG_LONG_TO_LANG_SHORT:
+             raise ValueError(f"tgt_lang must be one of {list(LANG_LONG_TO_LANG_SHORT)}")
+         else:
+             tgt_lang = LANG_LONG_TO_LANG_SHORT[tgt_lang]
+
+
+         # infer taskname from src_lang and tgt_lang
+         if src_lang == tgt_lang:
+             taskname = "asr"
+         else:
+             taskname = "s2t_translation"
+
+         # update pnc variable to be "yes" or "no"
+         pnc = "yes" if pnc else "no"
+
+         # make manifest file and save
+         manifest_data = {
+             "audio_filepath": converted_audio_filepath,
+             "source_lang": src_lang,
+             "target_lang": tgt_lang,
+             "taskname": taskname,
+             "pnc": pnc,
+             "answer": "predict",
+             "duration": str(duration),
+         }
+
+         manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
+
+         with open(manifest_filepath, 'w') as fout:
+             line = json.dumps(manifest_data)
+             fout.write(line + '\n')
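+         # for illustration, the manifest line just written looks roughly like:
+         # {"audio_filepath": "/tmp/.../<utt_id>.wav", "source_lang": "en", "target_lang": "de",
+         #  "taskname": "s2t_translation", "pnc": "yes", "answer": "predict", "duration": "12.3"}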
+
+         # call transcribe, passing in manifest filepath
+         model_output = model.transcribe(manifest_filepath)
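+         # transcribe() returns one prediction per manifest entry; the manifest holds a
+         # single line, so the text we want is element [0]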
+
+         return model_output[0]
+
+ # add logic to make sure dropdown menus only suggest valid combos
+ def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value):
+     """Callback function for when the src_lang or tgt_lang dropdown menus are changed.
+
+     Args:
+         src_lang_value (string), tgt_lang_value (string), pnc_value (bool) - the current
+             chosen "values" of each Gradio component
+     Returns:
+         src_lang, tgt_lang, pnc - these are the new Gradio components that will be displayed
+
+     Note: I found the required logic is easier to understand if you think about the possible src & tgt langs as
+     a matrix, e.g. with English, Spanish, French, German as the langs, and with only same-language transcription,
+     X -> English translation, and English -> X translation allowed, the matrix looks like the diagram below
+     ("Y" means it is allowed to go into that state).
+     It is easier to understand the code if you think about which state you are in, given the current src_lang_value and
+     tgt_lang_value, and then which states you can go to from there.
+
+                tgt lang
+             - |EN |ES |FR |DE
+             ------------------
+          EN | Y | Y | Y | Y
+             ------------------
+    src   ES | Y | Y |   |
+    lang     ------------------
+          FR | Y |   | Y |
+             ------------------
+          DE | Y |   |   | Y
+     """
+
+     if src_lang_value == "English" and tgt_lang_value == "English":
+         # src_lang and tgt_lang can go anywhere
+         src_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value=src_lang_value,
+             label="Input audio is spoken in:"
+         )
+         tgt_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value=tgt_lang_value,
+             label="Transcribe in language:"
+         )
+     elif src_lang_value == "English":
+         # src is English & tgt is non-English
+         # => src can only be English or the current tgt_lang_value
+         # & tgt can be anything
+         src_lang = gr.Dropdown(
+             choices=["English", tgt_lang_value],
+             value=src_lang_value,
+             label="Input audio is spoken in:"
+         )
+         tgt_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value=tgt_lang_value,
+             label="Transcribe in language:"
+         )
+     elif tgt_lang_value == "English":
+         # src is non-English & tgt is English
+         # => src can be anything
+         # & tgt can only be English or the current src_lang_value
+         src_lang = gr.Dropdown(
+             choices=["English", "Spanish", "French", "German"],
+             value=src_lang_value,
+             label="Input audio is spoken in:"
+         )
+         tgt_lang = gr.Dropdown(
+             choices=["English", src_lang_value],
+             value=tgt_lang_value,
+             label="Transcribe in language:"
+         )
+     else:
+         # both src and tgt are non-English
+         # => both src and tgt can only switch to English or stay as themselves
+         src_lang = gr.Dropdown(
+             choices=["English", src_lang_value],
+             value=src_lang_value,
+             label="Input audio is spoken in:"
+         )
+         tgt_lang = gr.Dropdown(
+             choices=["English", tgt_lang_value],
+             value=tgt_lang_value,
+             label="Transcribe in language:"
+         )
+     # let pnc be anything if src_lang_value == tgt_lang_value, else fix to True
+     if src_lang_value == tgt_lang_value:
+         pnc = gr.Checkbox(
+             value=pnc_value,
+             label="Punctuation & Capitalization in transcript?",
+             interactive=True
+         )
+     else:
+         pnc = gr.Checkbox(
+             value=True,
+             label="Punctuation & Capitalization in transcript?",
+             interactive=False
+         )
+     return src_lang, tgt_lang, pnc
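+
+ # For example (illustrative): with src_lang_value="French" and tgt_lang_value="English",
+ # the callback narrows the tgt choices to ["English", "French"] and locks pnc to True.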
+
+
+ with gr.Blocks(
+     title="NeMo Canary Model",
+     css="""
+     textarea { font-size: 18px;}
+     #model_output_text_box span {
+         font-size: 18px;
+         font-weight: bold;
+     }
+     """,
+     theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg)  # make text slightly bigger (default is text_md)
+ ) as demo:
+
+     gr.HTML("<h1 style='text-align: center'>NeMo Canary model: Transcribe & Translate audio</h1>")
+
+     with gr.Row():
+         with gr.Column():
+             gr.HTML("<p><b>Step 1:</b> Upload an audio file or record with your microphone.</p>")
+
+             audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")
+
+             gr.HTML("<p><b>Step 2:</b> Choose the input and output language.</p>")
+
+             src_lang = gr.Dropdown(
+                 choices=["English", "Spanish", "French", "German"],
+                 value="English",
+                 label="Input audio is spoken in:"
+             )
+
+         with gr.Column():
+             tgt_lang = gr.Dropdown(
+                 choices=["English", "Spanish", "French", "German"],
+                 value="English",
+                 label="Transcribe in language:"
+             )
+             pnc = gr.Checkbox(
+                 value=True,
+                 label="Punctuation & Capitalization in transcript?",
+             )
+
+         with gr.Column():
+
+             gr.HTML("<p><b>Step 3:</b> Run the model.</p>")
+
+             go_button = gr.Button(
+                 value="Run model",
+                 variant="primary",  # make "primary" so it stands out (default is "secondary")
+             )
+
+             model_output_text_box = gr.Textbox(
+                 label="Model Output",
+                 elem_id="model_output_text_box",
+             )
+
+     with gr.Row():
+
+         gr.HTML(
+             "<p style='text-align: center'>"
+             "🐀 <a href='#' target='_blank'>Canary model</a> | "
+             "🧑‍💻 <a href='https://github.com/NVIDIA/NeMo' target='_blank'>NeMo Repository</a>"
+             "</p>"
+         )
+
+     go_button.click(
+         fn=transcribe,
+         inputs=[audio_file, src_lang, tgt_lang, pnc],
+         outputs=[model_output_text_box]
+     )
+
+     # call on_src_or_tgt_lang_change whenever src_lang or tgt_lang dropdown menus are changed
+     src_lang.change(
+         fn=on_src_or_tgt_lang_change,
+         inputs=[src_lang, tgt_lang, pnc],
+         outputs=[src_lang, tgt_lang, pnc],
+     )
+     tgt_lang.change(
+         fn=on_src_or_tgt_lang_change,
+         inputs=[src_lang, tgt_lang, pnc],
+         outputs=[src_lang, tgt_lang, pnc],
+     )
+
+
+ demo.queue()
+ demo.launch()
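+ # Assumed local setup (not specified in this file): gradio, librosa, soundfile and NeMo
+ # installed (e.g. `pip install nemo_toolkit['asr']`), then launch with `python app.py`.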