Spaces:
Running
Running
import torchaudio | |
import torch | |
from transformers import ( | |
WhisperProcessor, | |
AutoProcessor, | |
AutoModelForSpeechSeq2Seq, | |
AutoModelForCTC, | |
Wav2Vec2Processor, | |
Wav2Vec2ForCTC | |
) | |
import numpy as np | |
import util | |
# Registry of the selectable speech-recognition back-ends, keyed by the model id
# that callers pass to ``transcribe``. Every processor/model pair below is
# downloaded and instantiated eagerly, once, when this module is imported.
# NOTE(review): judging by the checkpoint names, two entries are MMS-1B models,
# so importing this module is slow and memory-heavy — consider lazy loading.
#
# Fields of each entry:
#   processor      feature extractor + tokenizer matching the checkpoint
#   model          the pretrained network
#   ctc_model      True  -> transcribe() decodes via argmax over logits
#                  False -> transcribe() decodes via model.generate
#   arabic_script  True  -> the model emits Uyghur Arabic script
#                  False -> the model emits Latin script

# Each checkpoint id is named once so its processor and model cannot drift apart.
_OPENAI_WHISPER_CKPT = "openai/whisper-small"
_META_MMS_CKPT = "facebook/mms-1b-all"
_FT_WHISPER_CKPT = "ixxan/whisper-small-uyghur-common-voice"
_FT_MMS_CKPT = "ixxan/wav2vec2-large-mms-1b-uyghur-latin"

models_info = {}

# NOTE(review): transcribe() calls model.generate without language/task
# arguments, so the "uzbek" setting here may only configure the tokenizer —
# confirm which decoding language is actually intended.
models_info["OpenAI-Whisper"] = dict(
    processor=WhisperProcessor.from_pretrained(
        _OPENAI_WHISPER_CKPT, language="uzbek", task="transcribe"
    ),
    model=AutoModelForSpeechSeq2Seq.from_pretrained(_OPENAI_WHISPER_CKPT),
    ctc_model=False,
    arabic_script=False,
)

# ignore_mismatched_sizes: presumably needed because the Uyghur adapter's
# vocabulary head differs in size from the checkpoint's default head.
models_info["Meta-MMS"] = dict(
    processor=AutoProcessor.from_pretrained(
        _META_MMS_CKPT, target_lang="uig-script_arabic"
    ),
    model=AutoModelForCTC.from_pretrained(
        _META_MMS_CKPT,
        target_lang="uig-script_arabic",
        ignore_mismatched_sizes=True,
    ),
    ctc_model=True,
    arabic_script=True,
)

models_info["Ixxan-FineTuned-Whisper"] = dict(
    processor=AutoProcessor.from_pretrained(_FT_WHISPER_CKPT),
    model=AutoModelForSpeechSeq2Seq.from_pretrained(_FT_WHISPER_CKPT),
    ctc_model=False,
    arabic_script=False,
)

models_info["Ixxan-FineTuned-MMS"] = dict(
    processor=Wav2Vec2Processor.from_pretrained(
        _FT_MMS_CKPT, target_lang="uig-script_latin"
    ),
    model=Wav2Vec2ForCTC.from_pretrained(
        _FT_MMS_CKPT, target_lang="uig-script_latin"
    ),
    ctc_model=True,
    arabic_script=False,
)
# def transcribe(audio_data, model_id) -> str: | |
# if model_id == "Compare All Models": | |
# return transcribe_all_models(audio_data) | |
# else: | |
# return transcribe_with_model(audio_data, model_id) | |
# def transcribe_all_models(audio_data) -> dict: | |
# transcriptions = {} | |
# for model_id in models_info.keys(): | |
# transcriptions[model_id] = transcribe_with_model(audio_data, model_id) | |
# return transcriptions | |
def _load_waveform(audio_data):
    """Convert a Gradio-style audio value into ``(waveform, sampling_rate)``.

    ``waveform`` is a 1-D float32 torch tensor (multi-channel audio is averaged
    down to mono). Returns ``None`` when *audio_data* is neither a
    ``(sampling_rate, samples)`` tuple nor a file path string.
    """
    if isinstance(audio_data, tuple):
        # Microphone / numpy input: (sampling_rate, samples).
        sampling_rate, samples = audio_data
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.integer):
            # Integer PCM -> roughly [-1, 1): divide by the dtype's full-scale
            # value (32768 for int16, the usual microphone format).
            full_scale = float(np.iinfo(samples.dtype).max) + 1.0
            samples = samples.astype(np.float32) / full_scale
        else:
            # Already floating point; assumed to be normalized — keep values.
            samples = samples.astype(np.float32)
        # A tensor (not a numpy array) is required by torchaudio's Resample.
        waveform = torch.from_numpy(samples)
        if waveform.ndim == 2:
            # assumes Gradio's (samples, channels) layout for multi-channel
            # numpy audio — average the channels to mono.
            waveform = waveform.mean(dim=1)
    elif isinstance(audio_data, str):
        # File upload: torchaudio.load returns a (channels, frames) tensor.
        waveform, sampling_rate = torchaudio.load(audio_data)
        # Average channels to mono; for mono input this equals squeeze().
        waveform = waveform.mean(dim=0)
    else:
        return None
    return waveform, sampling_rate


def transcribe(audio_data, model_id) -> "tuple[str, str | None]":
    """Transcribe Uyghur speech with the model registered under *model_id*.

    Args:
        audio_data: Either a ``(sampling_rate, samples)`` tuple (microphone /
            numpy input, integer PCM or float samples, mono or multi-channel)
            or a path to an audio file (upload).
        model_id: Key into ``models_info`` selecting the processor/model pair.

    Returns:
        ``(arabic_text, latin_text)``: the transcription in Uyghur Arabic
        script and in Latin script, one of which is produced by
        transliterating the model's native output via ``util``. On invalid
        input, ``("<<ERROR: ...>>", None)`` instead.
    """
    loaded = _load_waveform(audio_data)
    if loaded is None:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data)), None
    waveform, sampling_rate = loaded
    if waveform.numel() == 0:
        return "<<ERROR: Empty audio input>>", None

    entry = models_info.get(model_id)
    if entry is None:
        return "<<ERROR: Unknown model: {}>>".format(model_id), None
    model = entry["model"]
    processor = entry["processor"]

    # Resample to the rate the model's feature extractor was trained with.
    target_sr = processor.feature_extractor.sampling_rate
    if sampling_rate != target_sr:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_sr)
        waveform = resampler(waveform)
        sampling_rate = target_sr

    inputs = processor(waveform.numpy(), sampling_rate=sampling_rate, return_tensors="pt")

    # Run on GPU when available; model and inputs must share a device.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        if entry["ctc_model"]:
            # CTC: greedy argmax over per-frame logits, then collapse/decode.
            logits = model(**inputs).logits
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.batch_decode(predicted_ids)[0]
        else:
            # Encoder-decoder (Whisper-style): autoregressive generation.
            generated_ids = model.generate(inputs["input_features"], max_length=225)
            transcription = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Return both scripts, transliterating from whichever one the model emits.
    if entry["arabic_script"]:
        transcription_arabic = transcription
        transcription_latin = util.ug_arab_to_latn(transcription)
    else:
        transcription_arabic = util.ug_latn_to_arab(transcription)
        transcription_latin = transcription
    print(model_id, transcription_arabic, transcription_latin)
    return transcription_arabic, transcription_latin