|
import numpy as np |
|
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC |
|
import torch |
|
import torchaudio |
|
import util |
|
|
|
|
|
# Hugging Face checkpoint: MMS-1B wav2vec2 CTC model fine-tuned for Uyghur
# written in Latin script (per the checkpoint name).
model_id = 'ixxan/wav2vec2-large-mms-1b-uyghur-latin'

# MMS language/adapter code; the model and the tokenizer must agree on it.
_TARGET_LANG = "uig-script_latin"

# Prefer the GPU when CUDA is available, otherwise run inference on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the CTC model for the target language and move it to the chosen device.
asr_model = Wav2Vec2ForCTC.from_pretrained(
    model_id,
    target_lang=_TARGET_LANG,
).to(device)

# Feature extractor + tokenizer pair; point the tokenizer's vocabulary at the
# same language code as the model.
asr_processor = Wav2Vec2Processor.from_pretrained(model_id)
asr_processor.tokenizer.set_target_lang(_TARGET_LANG)
|
|
|
def _pcm_to_float32(samples):
    """Return ``samples`` as float32, scaling integer PCM into roughly [-1, 1).

    Integer arrays are centred on their dtype's midpoint and divided by half
    the dtype's range (int16 -> x / 32768, uint8 -> (x - 128) / 128).
    Floating-point arrays are assumed to already be in [-1, 1] and are only
    cast to float32.
    """
    samples = np.asarray(samples)
    if np.issubdtype(samples.dtype, np.integer):
        info = np.iinfo(samples.dtype)
        # Python ints avoid overflow when combining the dtype's extremes.
        half_range = (int(info.max) - int(info.min) + 1) / 2.0
        midpoint = (int(info.max) + int(info.min) + 1) / 2.0
        return ((samples.astype(np.float64) - midpoint) / half_range).astype(np.float32)
    return samples.astype(np.float32)


def _load_mono_waveform(audio_data):
    """Decode ``audio_data`` into ``(mono float32 tensor, sampling_rate)``.

    Accepts a ``(sampling_rate, samples)`` tuple, where ``samples`` is laid
    out as ``(frames,)`` or ``(frames, channels)``, or a file path for
    ``torchaudio.load``, which yields ``(channels, frames)``. Returns
    ``None`` for any other input type.
    """
    if isinstance(audio_data, tuple):
        sampling_rate, samples = audio_data
        # Resample needs a torch Tensor, not a NumPy array.
        waveform = torch.from_numpy(_pcm_to_float32(samples))
        channel_dim = 1
    elif isinstance(audio_data, str):
        waveform, sampling_rate = torchaudio.load(audio_data)
        channel_dim = 0
    else:
        return None

    # Drop singleton dims first, so mono audio in either layout becomes 1-D.
    waveform = waveform.squeeze()
    if waveform.ndim == 2:
        # Genuine multi-channel audio: average the channels into one stream.
        # Otherwise resampling would run across the channel axis, or each
        # channel would be decoded as a separate clip.
        waveform = waveform.mean(dim=channel_dim)
    return waveform, sampling_rate


def asr(audio_data, target_rate=16000):
    """Transcribe speech with the module-level wav2vec2 CTC model.

    Args:
        audio_data: Either a ``(sampling_rate, samples)`` tuple, where
            ``samples`` is a NumPy array shaped ``(frames,)`` or
            ``(frames, channels)`` (integer PCM or float in [-1, 1]), or a
            path to an audio file readable by ``torchaudio.load``.
            Multi-channel audio is averaged down to mono.
        target_rate: Sampling rate, in Hz, that the audio is resampled to
            before feature extraction. Defaults to 16000.

    Returns:
        The greedy (argmax) CTC transcript as decoded by ``asr_processor``.
        Unsupported input types and empty audio do not raise. Instead, a
        string starting with ``"<<ERROR:"`` is returned in place of the
        transcript.
    """
    loaded = _load_mono_waveform(audio_data)
    if loaded is None:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))
    waveform, sampling_rate = loaded

    # The model's convolutional front end cannot run on zero samples.
    if waveform.numel() == 0:
        return "<<ERROR: Empty Audio Input>>"

    if sampling_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
        waveform = resampler(waveform)
        sampling_rate = target_rate

    inputs = asr_processor(waveform.numpy(), sampling_rate=sampling_rate, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}

    with torch.no_grad():
        logits = asr_model(**inputs).logits

    predicted_ids = torch.argmax(logits, dim=-1)
    return asr_processor.batch_decode(predicted_ids)[0]
|
|
|
|
|
def check_pronunciation(input_text, script_choice, user_audio):
    """Transcribe a recording and score it against a reference text.

    Steps:
        1. Transcribe ``user_audio`` with ``asr`` (Latin-script output).
        2. Convert that transcript with ``util.ug_latn_to_arab``
           (presumably Latin -> Arabic-script Uyghur; confirm in ``util``).
        3. Pass the converted transcript, ``input_text`` and
           ``script_choice`` to ``util.calculate_pronunciation_accuracy``.
        4. Print the Latin transcript to stdout.

    NOTE(review): if ``asr`` returns one of its ``"<<ERROR: ...>>"``
    strings, that string is converted and scored like a real transcript.

    Returns:
        A 6-tuple ``(arabic_transcript, latin_transcript, correct_phoneme,
        user_phoneme, pronunciation_match, pronunciation_score)``. The last
        four are the values unpacked from
        ``util.calculate_pronunciation_accuracy``.
    """
    latin_transcript = asr(user_audio)
    arabic_transcript = util.ug_latn_to_arab(latin_transcript)

    (
        correct_phoneme,
        user_phoneme,
        pronunciation_match,
        pronunciation_score,
    ) = util.calculate_pronunciation_accuracy(
        reference_text=input_text,
        output_text=arabic_transcript,
        script_choice=script_choice,
    )

    print(f"ASR: {latin_transcript}")

    return (
        arabic_transcript,
        latin_transcript,
        correct_phoneme,
        user_phoneme,
        pronunciation_match,
        pronunciation_score,
    )