File size: 2,587 Bytes
f6cde70 d29fa84 f6cde70 d29fa84 f6cde70 d29fa84 f6cde70 d42db5f f6cde70 badbaf4 d29fa84 6c7e7fa d29fa84 8e57d14 d29fa84 f0e249a 16f2d30 fe67771 d29fa84 8e57d14 d29fa84 a95ac22 fe67771 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import torchaudio
import util
# Choose the compute device up front: CUDA when available, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Identifier handed to `from_pretrained` for both the model and the processor.
model_id = 'ixxan/wav2vec2-large-mms-1b-uyghur-latin'

# Load the CTC model for the "uig-script_latin" target language and place it
# on the selected device in one step.
asr_model = Wav2Vec2ForCTC.from_pretrained(
    model_id,
    target_lang="uig-script_latin",
).to(device)

# Load the matching processor and point its tokenizer at the same target
# language as the model.
asr_processor = Wav2Vec2Processor.from_pretrained(model_id)
asr_processor.tokenizer.set_target_lang("uig-script_latin")
def asr(audio_data, target_rate = 16000):
    """Transcribe an audio clip with the module-level wav2vec2 CTC model.

    Args:
        audio_data: Either a ``(sampling_rate, samples)`` tuple (microphone
            capture) or a ``str`` path to an audio file (file upload).
        target_rate: Sampling rate in Hz that the audio is resampled to
            before it is handed to the processor.

    Returns:
        The decoded transcript string, or a ``"<<ERROR: ...>>"`` string when
        ``audio_data`` is neither a tuple nor a ``str``.
    """
    # Load user audio
    if isinstance(audio_data, tuple):
        # microphone
        sampling_rate, samples = audio_data
        # Scale by the int16 full-scale value
        # (assumes 16-bit PCM samples -- TODO confirm against the caller).
        samples = (np.asarray(samples) / 32768.0).astype(np.float32)
        # torchaudio's Resample operates on tensors, not NumPy arrays, so
        # convert here; otherwise resampling microphone input fails.
        waveform = torch.from_numpy(samples)
    elif isinstance(audio_data, str):
        # file upload
        waveform, sampling_rate = torchaudio.load(audio_data)
    else:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))

    # Drop singleton dimensions, then downmix to mono if several channels
    # remain. The channel axis is taken to be the shortest one, which covers
    # both (channels, frames) and (frames, channels) layouts.
    waveform = waveform.squeeze()
    if waveform.ndim > 1:
        channel_axis = min(range(waveform.ndim), key=lambda axis: waveform.shape[axis])
        waveform = waveform.mean(dim=channel_axis)

    # NOTE: a 10-second duration cap used to be checked here; it is disabled.

    # Resample if needed
    if sampling_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
        waveform = resampler(waveform)
        sampling_rate = target_rate

    # Process audio through ASR model; the processor is given a 1-D NumPy
    # array regardless of which input branch produced the audio.
    inputs = asr_processor(waveform.numpy(), sampling_rate=sampling_rate, return_tensors="pt")
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        logits = asr_model(**inputs).logits

    # Greedy CTC decoding: most likely token per frame, then decode the batch
    # and take its single entry.
    predicted_ids = torch.argmax(logits, dim=-1)
    transcript = asr_processor.batch_decode(predicted_ids)[0]
    return transcript
def check_pronunciation(input_text, script_choice, user_audio):
    """Transcribe the user's audio and score it against a reference text.

    Args:
        input_text: Reference text, forwarded to
            ``util.calculate_pronunciation_accuracy`` as ``reference_text``.
        script_choice: Forwarded unchanged to
            ``util.calculate_pronunciation_accuracy``.
        user_audio: Audio input handed to ``asr``.

    Returns:
        A 6-tuple: the Arabic-script transcript, the Latin-script transcript,
        and the four values returned by
        ``util.calculate_pronunciation_accuracy`` (presumably reference
        phonemes, user phonemes, match feedback and a score -- confirm
        against ``util``).
    """
    # The ASR model emits Latin script; derive the Arabic-script form from it.
    latin_transcript = asr(user_audio)
    arabic_transcript = util.ug_latn_to_arab(latin_transcript)

    # Phoneme comparison between the reference text and what was recognised.
    feedback = util.calculate_pronunciation_accuracy(
        reference_text=input_text,
        output_text=arabic_transcript,
        script_choice=script_choice,
    )
    reference_phoneme, spoken_phoneme, match_result, score = feedback

    print(f"ASR: {latin_transcript}")

    return (
        arabic_transcript,
        latin_transcript,
        reference_phoneme,
        spoken_phoneme,
        match_result,
        score,
    )