Irpan
asr
1dfec92
raw
history blame
4.35 kB
import torchaudio
import torch
from transformers import (
WhisperProcessor,
AutoProcessor,
AutoModelForSpeechSeq2Seq,
AutoModelForCTC,
Wav2Vec2Processor,
Wav2Vec2ForCTC
)
import numpy as np
import util
# Registry of the available ASR back-ends, keyed by the identifier that is
# passed to transcribe() as `model_id`. Every processor/model pair is loaded
# eagerly when this module is imported. Each entry carries:
#   "processor":     feature extractor + tokenizer wrapper for the checkpoint
#   "model":         the loaded network
#   "ctc_model":     True  -> decoded via argmax over the logits
#                    False -> decoded via model.generate()
#   "arabic_script": True when the raw transcription is in Arabic script,
#                    False when it is in Latin script
models_info = {}

# Stock Whisper-small checkpoint, with the processor set to Uzbek transcription.
models_info["OpenAI-Whisper-Uzbek"] = {
    "processor": WhisperProcessor.from_pretrained(
        "openai/whisper-small", language="uzbek", task="transcribe"
    ),
    "model": AutoModelForSpeechSeq2Seq.from_pretrained("openai/whisper-small"),
    "ctc_model": False,
    "arabic_script": False,
}

# Whisper-small checkpoint published as ixxan/whisper-small-thugy20.
models_info["ixxan/whisper-small-thugy20"] = {
    "processor": AutoProcessor.from_pretrained("ixxan/whisper-small-thugy20"),
    "model": AutoModelForSpeechSeq2Seq.from_pretrained(
        "ixxan/whisper-small-thugy20"
    ),
    "ctc_model": False,
    "arabic_script": False,
}

# Whisper-small checkpoint published as ixxan/whisper-small-uyghur-common-voice.
models_info["ixxan/whisper-small-uyghur-common-voice"] = {
    "processor": AutoProcessor.from_pretrained(
        "ixxan/whisper-small-uyghur-common-voice"
    ),
    "model": AutoModelForSpeechSeq2Seq.from_pretrained(
        "ixxan/whisper-small-uyghur-common-voice"
    ),
    "ctc_model": False,
    "arabic_script": False,
}

# Meta MMS checkpoint loaded for the Uyghur Arabic-script target; the only
# entry whose raw output is flagged as Arabic script.
models_info["Meta-MMS"] = {
    "processor": AutoProcessor.from_pretrained(
        "facebook/mms-1b-all", target_lang='uig-script_arabic'
    ),
    "model": AutoModelForCTC.from_pretrained(
        "facebook/mms-1b-all",
        target_lang='uig-script_arabic',
        ignore_mismatched_sizes=True,
    ),
    "ctc_model": True,
    "arabic_script": True,
}

# Wav2Vec2/MMS checkpoint loaded for the Uyghur Latin-script target.
models_info["ixxan/wav2vec2-large-mms-1b-uyghur-latin"] = {
    "processor": Wav2Vec2Processor.from_pretrained(
        "ixxan/wav2vec2-large-mms-1b-uyghur-latin",
        target_lang='uig-script_latin',
    ),
    "model": Wav2Vec2ForCTC.from_pretrained(
        "ixxan/wav2vec2-large-mms-1b-uyghur-latin",
        target_lang='uig-script_latin',
    ),
    "ctc_model": True,
    "arabic_script": False,
}
# def transcribe(audio_data, model_id) -> str:
# if model_id == "Compare All Models":
# return transcribe_all_models(audio_data)
# else:
# return transcribe_with_model(audio_data, model_id)
# def transcribe_all_models(audio_data) -> dict:
# transcriptions = {}
# for model_id in models_info.keys():
# transcriptions[model_id] = transcribe_with_model(audio_data, model_id)
# return transcriptions
def transcribe(audio_data, model_id):
    """Transcribe speech with one of the models registered in ``models_info``.

    Args:
        audio_data: Either a ``(sampling_rate, samples)`` tuple (microphone
            capture; ``samples`` is array-like) or a ``str`` path to an audio
            file readable by ``torchaudio.load``.
        model_id: Key into ``models_info`` selecting the processor/model pair.

    Returns:
        On success, a ``(transcription_arabic, transcription_latin)`` tuple of
        strings. On empty or unsupported input, a single ``"<<ERROR: ...>>"``
        string.
        NOTE(review): the success and error shapes differ; confirm that the
        caller handles both.

    Raises:
        KeyError: If ``model_id`` is not a key of ``models_info``.
    """
    if not audio_data:
        return "<<ERROR: Empty Audio Input>>"

    if isinstance(audio_data, tuple):
        # Microphone capture.
        sampling_rate, samples = audio_data
        samples = np.asarray(samples)
        if samples.size == 0:
            return "<<ERROR: Empty Audio Input>>"
        if np.issubdtype(samples.dtype, np.integer):
            # Scale integer PCM by its own range (int16 -> / 32768.0, exactly
            # as before); float input is left unscaled.
            samples = samples / float(np.iinfo(samples.dtype).max + 1)
        samples = np.atleast_1d(np.squeeze(samples.astype(np.float32)))
        if samples.ndim > 1:
            # Downmix to mono.
            # NOTE(review): assumes a (samples, channels) layout -- confirm
            # against the component that produces the microphone tuple.
            samples = samples.reshape(samples.shape[0], -1).mean(axis=1)
        # The resampler below only accepts tensors, not NumPy arrays.
        waveform = torch.from_numpy(samples)
    elif isinstance(audio_data, str):
        # File upload.
        waveform, sampling_rate = torchaudio.load(audio_data)
        if waveform.numel() == 0:
            return "<<ERROR: Empty Audio Input>>"
        # torchaudio.load yields (channels, frames); average the channels so
        # multi-channel files are not mistaken for a batch by the processor.
        waveform = waveform.mean(dim=0)
    else:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))

    entry = models_info[model_id]
    model = entry["model"]
    processor = entry["processor"]
    target_sr = processor.feature_extractor.sampling_rate

    # Resample if needed
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
        waveform = resampler(waveform)

    # Preprocess the audio input
    inputs = processor(waveform.numpy(), sampling_rate=target_sr, return_tensors="pt")

    # Move model and inputs to GPU if available
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    # Generate transcription
    with torch.no_grad():
        if entry["ctc_model"]:
            # CTC models: take the most likely token per frame and let the
            # processor decode the id sequence.
            logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.batch_decode(predicted_ids)[0]
        else:
            # Seq2seq models: autoregressive generation.
            generated_ids = model.generate(inputs["input_features"], max_length=225)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Always return both scripts, converting from whichever one the model emits.
    if entry["arabic_script"]:
        transcription_arabic = transcription
        transcription_latin = util.ug_arab_to_latn(transcription)
    else:  # Latin script output
        transcription_arabic = util.ug_latn_to_arab(transcription)
        transcription_latin = transcription
    return transcription_arabic, transcription_latin