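# Gradio demo for facebook/wav2vec2-xls-r-300m-en-to-15:
# record English speech from the microphone and translate it into one of 15 target languages.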
import gradio as gr
import librosa
from transformers import AutoFeatureExtractor, AutoTokenizer, SpeechEncoderDecoderModel
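# Load the speech encoder-decoder checkpoint together with its feature extractor
# (for preparing the audio input) and tokenizer (for decoding the generated text).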
# `use_auth_token=True` uses a locally configured Hugging Face token in case the checkpoint requires authentication.
feature_extractor = AutoFeatureExtractor.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15", use_auth_token=True)
tokenizer = AutoTokenizer.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15", use_auth_token=True, use_fast=False)
model = SpeechEncoderDecoderModel.from_pretrained("facebook/wav2vec2-xls-r-300m-en-to-15", use_auth_token=True)
def process_audio_file(file):
    # Load the recording and resample to the 16 kHz rate the model expects.
    data, sr = librosa.load(file)
    if sr != 16000:
        data = librosa.resample(data, orig_sr=sr, target_sr=16000)
    # Convert the raw waveform into normalized model inputs.
    input_values = feature_extractor(data, sampling_rate=16000, return_tensors="pt").input_values
    return input_values
def transcribe(target_language, file):
    # Dropdown entries look like "German (de)"; extract the language code between the parentheses.
    target_code = target_language.split("(")[-1].split(")")[0]
    forced_bos_token_id = MAPPING[target_code]
    input_values = process_audio_file(file)
    # Force the decoder to start with the target-language token so generation translates into that language.
    sequences = model.generate(input_values, forced_bos_token_id=forced_bos_token_id)
    transcription = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return transcription[0]
target_languages = [
    "German (de)",
    "Turkish (tr)",
    "Persian (fa)",
    "Swedish (sv)",
    "Mongolian (mn)",
    "Chinese (zh)",
    "Welsh (cy)",
    "Catalan (ca)",
    "Slovenian (sl)",
    "Estonian (et)",
    "Indonesian (id)",
    "Arabic (ar)",
    "Tamil (ta)",
    "Latvian (lv)",
    "Japanese (ja)",
]
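# Decoder token ids used as forced BOS tokens, one per target language; these appear to
# correspond to the language-code tokens of the mBART-50 style tokenizer used by the decoder.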
MAPPING = {
    "de": 250003,
    "tr": 250023,
    "fa": 250029,
    "sv": 250042,
    "mn": 250037,
    "zh": 250025,
    "cy": 250007,
    "ca": 250005,
    "sl": 250052,
    "et": 250006,
    "id": 250032,
    "ar": 250001,
    "ta": 250044,
    "lv": 250017,
    "ja": 250012,
}
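# Wire the app together: a target-language dropdown and a microphone recording go in, translated text comes out.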
iface = gr.Interface(
    fn=transcribe,
    inputs=[
        # Input order must match the signature of `transcribe(target_language, file)`.
        gr.inputs.Dropdown(target_languages),
        gr.inputs.Audio(source="microphone", type="filepath"),
    ],
    outputs="text",
    layout="horizontal",
    theme="huggingface",
)
iface.launch()