|
import gradio as gr |
|
import json |
|
import librosa |
|
import os |
|
import soundfile as sf |
|
import tempfile |
|
import uuid |
|
|
|
import torch |
|
import transformers |
|
|
|
from nemo.collections.asr.models import ASRModel |
|
from nemo.collections.asr.parts.utils.streaming_utils import FrameBatchMultiTaskAED |
|
from nemo.collections.asr.parts.utils.transcribe_utils import get_buffered_pred_feat_multitaskAED |
|
|
|
# Sample rate (Hz) that convert_audio() writes its output files at.
SAMPLE_RATE = 16000
# Upper bound, in minutes, on the audio length convert_audio() accepts.
MAX_AUDIO_MINUTES = 10

# Load the pretrained Canary checkpoint and switch it to eval mode.
model = ASRModel.from_pretrained("nvidia/canary-1b")
model.eval()

# Clear the decoding strategy first, then re-apply the model's own decoding
# config with the beam size forced to 1.
model.change_decoding_strategy(None)
decoding_cfg = model.cfg.decoding
decoding_cfg.beam.beam_size = 1
model.change_decoding_strategy(decoding_cfg)

# Preprocessor overrides used for buffered inference: dither and pad_to are
# both zeroed on the model config.
model.cfg.preprocessor.dither = 0.0
model.cfg.preprocessor.pad_to = 0

# Stride of the model output in seconds, derived from the feature stride.
# NOTE(review): the factor 8 is presumably the encoder's subsampling factor
# -- confirm against the model architecture.
feature_stride = model.cfg.preprocessor['window_stride']
model_stride_in_secs = feature_stride * 8

# Buffered-inference helper used by transcribe() for clips of 40 s or more;
# frame length and total buffer are both 40 s.
frame_asr = FrameBatchMultiTaskAED(
    asr_model=model,
    frame_len=40.0,
    total_buffer=40.0,
    batch_size=16,
)

# dtype handed to torch.cuda.amp.autocast in transcribe().
amp_dtype = torch.float16
|
|
|
|
|
# Causal LM that main() uses to answer the transcribed/translated text.
llm_model = transformers.AutoModelForCausalLM.from_pretrained(
    "microsoft/Phi-3-mini-128k-instruct",
    device_map="auto",
    torch_dtype="auto",
    trust_remote_code=True,
)

# Keyword arguments forwarded to every llm_pipe(...) call in main().
# Sampling is switched off and the full text is requested back.
generation_args = dict(
    max_new_tokens=500,
    return_full_text=True,
    temperature=0.0,
    do_sample=False,
)

# Tokenizer for the same checkpoint as llm_model.
tokenizer = transformers.AutoTokenizer.from_pretrained("microsoft/Phi-3-mini-128k-instruct")

# Text-generation pipeline bundling the model and tokenizer above.
llm_pipe = transformers.pipeline(
    "text-generation",
    model=llm_model,
    tokenizer=tokenizer,
)
|
|
|
def convert_audio(audio_filepath, tmpdir, utt_id):
    """
    Write a mono, SAMPLE_RATE-Hz wav copy of the input audio into tmpdir.

    Args:
        audio_filepath: path of the audio file to load.
        tmpdir: directory in which the converted file is created.
        utt_id: string used as the stem of the output file name.

    Returns:
        (out_filename, duration): path of the written wav file and the
        duration of the loaded audio as reported by librosa.get_duration.

    Raises:
        gr.Error: if the audio is longer than MAX_AUDIO_MINUTES minutes;
            nothing is written in that case.
    """
    # Load at the file's native rate (sr=None), downmixed to one channel.
    samples, native_sr = librosa.load(audio_filepath, sr=None, mono=True)
    duration = librosa.get_duration(y=samples, sr=native_sr)

    # Reject over-long audio before resampling or writing anything.
    if duration / 60.0 > MAX_AUDIO_MINUTES:
        raise gr.Error(
            f"This demo can transcribe up to {MAX_AUDIO_MINUTES} minutes of audio. "
            "If you wish, you may trim the audio using the Audio viewer in Step 1 "
            "(click on the scissors icon to start trimming audio)."
        )

    # Resample only when the native rate differs from the target rate.
    if native_sr != SAMPLE_RATE:
        samples = librosa.resample(samples, orig_sr=native_sr, target_sr=SAMPLE_RATE)

    out_filename = os.path.join(tmpdir, utt_id + '.wav')
    sf.write(out_filename, samples, SAMPLE_RATE)

    return out_filename, duration
|
|
|
|
|
def transcribe(audio_filepath, src_lang, tgt_lang, pnc):
    """
    Transcribe or translate one audio file with the Canary model.

    Args:
        audio_filepath: path of the input audio, or None if none was given.
        src_lang: long name of the spoken language ("English", "Spanish",
            "French" or "German").
        tgt_lang: long name of the output language (same choices).
        pnc: truthy to request punctuation and capitalization.

    Returns:
        The text produced by the model for the audio.

    Raises:
        gr.Error: if audio_filepath is None, or (via convert_audio) if the
            audio is too long.
        ValueError: if src_lang or tgt_lang is not a supported long name.
    """
    if audio_filepath is None:
        raise gr.Error("Please provide some input audio: either upload an audio file or use the microphone")

    utt_id = uuid.uuid4()

    # The converted audio and the manifest both live in the temporary
    # directory, so inference has to finish before the context exits.
    with tempfile.TemporaryDirectory() as tmpdir:
        converted_audio_filepath, duration = convert_audio(audio_filepath, tmpdir, str(utt_id))

        # Map the long language names from the UI to the short codes that
        # are written into the manifest.
        LANG_LONG_TO_LANG_SHORT = {
            "English": "en",
            "Spanish": "es",
            "French": "fr",
            "German": "de",
        }
        if src_lang not in LANG_LONG_TO_LANG_SHORT:
            raise ValueError(f"src_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}")
        if tgt_lang not in LANG_LONG_TO_LANG_SHORT:
            raise ValueError(f"tgt_lang must be one of {LANG_LONG_TO_LANG_SHORT.keys()}")
        src_lang = LANG_LONG_TO_LANG_SHORT[src_lang]
        tgt_lang = LANG_LONG_TO_LANG_SHORT[tgt_lang]

        # Same language on both sides means plain ASR; otherwise translation.
        taskname = "asr" if src_lang == tgt_lang else "s2t_translation"

        # The manifest carries the flag as the strings "yes" / "no".
        if pnc:
            pnc = "yes"
        else:
            pnc = "no"

        # Single-entry manifest describing this request, written as one
        # JSON line.
        manifest_data = {
            "audio_filepath": converted_audio_filepath,
            "source_lang": src_lang,
            "target_lang": tgt_lang,
            "taskname": taskname,
            "pnc": pnc,
            "answer": "predict",
            "duration": str(duration),
        }
        manifest_filepath = os.path.join(tmpdir, f'{utt_id}.json')
        with open(manifest_filepath, 'w') as fout:
            fout.write(json.dumps(manifest_data) + '\n')

        if duration < 40:
            # Clips shorter than 40 s go straight through model.transcribe.
            output_text = model.transcribe(manifest_filepath)[0]
        else:
            # Longer clips use the buffered helper under autocast, with
            # gradients disabled.
            with torch.cuda.amp.autocast(dtype=amp_dtype), torch.no_grad():
                hyps = get_buffered_pred_feat_multitaskAED(
                    frame_asr,
                    model.cfg.preprocessor,
                    model_stride_in_secs,
                    model.device,
                    manifest=manifest_filepath,
                    filepaths=None,
                )
                output_text = hyps[0].text

    return output_text
|
|
|
|
|
def on_src_or_tgt_lang_change(src_lang_value, tgt_lang_value, pnc_value):
    """Rebuild the language dropdowns and PnC checkbox after a language change.

    Args:
        src_lang_value (string): currently selected input language.
        tgt_lang_value (string): currently selected output language.
        pnc_value (bool): current state of the punctuation checkbox.
    Returns:
        src_lang, tgt_lang, pnc - the new Gradio components to display.
        Each dropdown keeps its current value; only the offered choices
        change. The checkbox is editable only when both languages match,
        and is otherwise forced to True and made non-interactive.
    """
    every_language = ["English", "Spanish", "French", "German"]
    src_is_english = src_lang_value == "English"
    tgt_is_english = tgt_lang_value == "English"

    if src_is_english and tgt_is_english:
        # English -> English: every language is offered on both sides.
        src_choices = list(every_language)
        tgt_choices = list(every_language)
    elif src_is_english:
        # English input, non-English output: the input side is limited to
        # English or the chosen output language.
        src_choices = ["English", tgt_lang_value]
        tgt_choices = list(every_language)
    elif tgt_is_english:
        # Non-English input, English output: the output side is limited to
        # English or the chosen input language.
        src_choices = list(every_language)
        tgt_choices = ["English", src_lang_value]
    else:
        # Neither side is English: each side offers English or its own
        # current value.
        src_choices = ["English", src_lang_value]
        tgt_choices = ["English", tgt_lang_value]

    src_lang = gr.Dropdown(
        choices=src_choices,
        value=src_lang_value,
        label="Input audio is spoken in:"
    )
    tgt_lang = gr.Dropdown(
        choices=tgt_choices,
        value=tgt_lang_value,
        label="Transcribe in language:"
    )

    same_language = src_lang_value == tgt_lang_value
    pnc = gr.Checkbox(
        value=pnc_value if same_language else True,
        label="Punctuation & Capitalization in transcript?",
        interactive=same_language
    )
    return src_lang, tgt_lang, pnc
|
|
|
def main(audio_filepath, src_lang, tgt_lang, pnc):
    """Transcribe/translate the audio and feed the resulting text to the LLM.

    Args:
        audio_filepath: path of the input audio (None if none was given).
        src_lang: long name of the spoken language.
        tgt_lang: long name of the output language.
        pnc: whether punctuation and capitalization are requested.

    Returns:
        The text generated by llm_pipe as a plain string, suitable for the
        output Textbox. An empty string is returned if the pipeline yields
        no results.

    Raises:
        Whatever transcribe() raises (gr.Error / ValueError).
    """
    translated = transcribe(audio_filepath, src_lang, tgt_lang, pnc)
    outputs = llm_pipe(translated, **generation_args)

    # A text-generation pipeline returns a list of result dicts for a single
    # prompt. Returning that list as-is makes the Textbox show its repr, so
    # pull out the generated string instead.
    if not outputs:
        return ""
    first = outputs[0]
    if isinstance(first, dict):
        return first.get("generated_text", "")
    return str(first)
|
|
|
|
|
|
|
# Build the Gradio UI.
# NOTE(review): the source's indentation was lost, so the row/column nesting
# below is reconstructed from statement order -- confirm against the
# intended layout.
with gr.Blocks(
    title="MyAlexa",
    css="""
        textarea { font-size: 18px;}
        #model_output_text_box span {
            font-size: 18px;
            font-weight: bold;
        }
    """,
    theme=gr.themes.Default(text_size=gr.themes.sizes.text_lg),
) as demo:

    gr.HTML("<h1 style='text-align: center'>MyAlexa</h1>")

    with gr.Row():
        # Audio input and source-language selection.
        with gr.Column():
            gr.HTML(
                "<p>Upload an audio file or record with your microphone.</p>"
            )

            audio_file = gr.Audio(sources=["microphone", "upload"], type="filepath")

            gr.HTML("<p>Choose the input and output language.</p>")

            src_lang = gr.Dropdown(
                choices=["English", "Spanish", "French", "German"],
                value="English",
                label="Input audio is spoken in:",
            )

        # Target-language selection and the punctuation option.
        with gr.Column():
            tgt_lang = gr.Dropdown(
                choices=["English", "Spanish", "French", "German"],
                value="English",
                label="Transcribe in language:",
            )
            pnc = gr.Checkbox(
                value=True,
                label="Punctuation & Capitalization in transcript?",
            )

        # Run button and model output.
        with gr.Column():

            gr.HTML("<p>Run the model.</p>")

            go_button = gr.Button(
                value="Run model",
                variant="primary",
            )

            model_output_text_box = gr.Textbox(
                label="Model Output",
                elem_id="model_output_text_box",
            )

    # Clicking the button runs transcription followed by the LLM and puts
    # the result in the output text box.
    go_button.click(
        fn=main,
        inputs=[audio_file, src_lang, tgt_lang, pnc],
        outputs=[model_output_text_box],
    )

    # A change on either dropdown refreshes both dropdowns and the checkbox.
    # The source dropdown is registered first, as before.
    for _lang_dropdown in (src_lang, tgt_lang):
        _lang_dropdown.change(
            fn=on_src_or_tgt_lang_change,
            inputs=[src_lang, tgt_lang, pnc],
            outputs=[src_lang, tgt_lang, pnc],
        )

# Enable request queuing, then start the app.
demo.queue()
demo.launch()