File size: 2,587 Bytes
f6cde70
d29fa84
 
f6cde70
d29fa84
 
 
 
 
 
 
 
 
 
 
 
f6cde70
d29fa84
f6cde70
 
 
 
 
 
 
 
 
 
d42db5f
 
 
 
 
f6cde70
 
 
 
badbaf4
d29fa84
 
6c7e7fa
d29fa84
 
 
 
 
 
 
 
8e57d14
d29fa84
 
f0e249a
16f2d30
 
fe67771
d29fa84
 
8e57d14
d29fa84
a95ac22
 
fe67771
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import numpy as np
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
import torchaudio
import util 

# Checkpoint used for speech recognition
model_id = 'ixxan/wav2vec2-large-mms-1b-uyghur-latin'

# Run on the GPU when CUDA is available, otherwise stay on the CPU
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# CTC model loaded with the "uig-script_latin" target language and moved to
# the selected device in one step
asr_model = Wav2Vec2ForCTC.from_pretrained(
    model_id,
    target_lang="uig-script_latin",
).to(device)

# Matching processor; its tokenizer is switched to the same target language
asr_processor = Wav2Vec2Processor.from_pretrained(model_id)
asr_processor.tokenizer.set_target_lang("uig-script_latin")

def asr(audio_data, target_rate=16000):
    """Transcribe a speech clip with the module-level wav2vec2 CTC model.

    Args:
        audio_data: Either a ``(sampling_rate, samples)`` tuple (microphone
            capture, with ``samples`` convertible to a NumPy array) or a
            ``str`` path to an audio file readable by ``torchaudio.load``.
        target_rate: Sampling rate in Hz the audio is resampled to before it
            is passed to the processor.

    Returns:
        The decoded transcript string, or an ``<<ERROR: ...>>`` string when
        ``audio_data`` is neither a tuple nor a str.
    """
    if isinstance(audio_data, tuple):
        # microphone
        sampling_rate, samples = audio_data
        samples = np.asarray(samples)
        if np.issubdtype(samples.dtype, np.integer):
            # Integer PCM: scale by the dtype's full range (32768 for int16,
            # which is what the previous hard-coded divisor assumed).
            full_scale = float(np.iinfo(samples.dtype).max) + 1.0
            samples = (samples / full_scale).astype(np.float32)
        else:
            # NOTE(review): float input is assumed to be already normalised
            # to [-1, 1] and is therefore not rescaled - confirm with caller.
            samples = samples.astype(np.float32)

        # Collapse to a single channel. The channel axis is taken to be the
        # shorter one, so both (samples, channels) and (channels, samples)
        # layouts are handled.
        samples = np.squeeze(samples)
        if samples.ndim == 2:
            samples = samples.mean(axis=int(np.argmin(samples.shape)))
        samples = np.atleast_1d(samples)

        # The torchaudio resampler only accepts tensors, not NumPy arrays.
        waveform = torch.from_numpy(
            np.ascontiguousarray(samples, dtype=np.float32)
        )
    elif isinstance(audio_data, str):
        # file upload
        waveform, sampling_rate = torchaudio.load(audio_data)
        if waveform.dim() > 1:
            # torchaudio.load returns (channels, frames) by default; average
            # the channels so stereo files become a single mono signal.
            waveform = waveform.mean(dim=0)
    else:
        return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))

    # Resample if needed
    if sampling_rate != target_rate:
        resampler = torchaudio.transforms.Resample(sampling_rate, target_rate)
        waveform = resampler(waveform)
        sampling_rate = target_rate

    # Process audio through ASR model; the processor always receives a 1-D
    # float32 NumPy array regardless of which input branch was taken.
    inputs = asr_processor(
        waveform.numpy(), sampling_rate=sampling_rate, return_tensors="pt"
    )
    inputs = {key: val.to(device) for key, val in inputs.items()}
    with torch.no_grad():
        logits = asr_model(**inputs).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        transcript = asr_processor.batch_decode(predicted_ids)[0]
    return transcript

    
def check_pronunciation(input_text, script_choice, user_audio):
    """Transcribe the user's recording and score it against a reference text.

    Args:
        input_text: Reference text, forwarded as ``reference_text`` to
            ``util.calculate_pronunciation_accuracy``.
        script_choice: Forwarded unchanged as ``script_choice`` to
            ``util.calculate_pronunciation_accuracy``.
        user_audio: Audio input handed to ``asr``.

    Returns:
        A 6-tuple: the output of ``util.ug_latn_to_arab`` applied to the ASR
        transcript, the raw ASR transcript, and the four values returned by
        ``util.calculate_pronunciation_accuracy`` (correct phoneme, user
        phoneme, pronunciation match, pronunciation score).
    """
    # Recognise the audio, then derive the second transcript from it
    # (presumably a Latin -> Arabic script conversion, going by the helper's
    # name - verify against util).
    latin_transcript = asr(user_audio)
    arabic_transcript = util.ug_latn_to_arab(latin_transcript)

    # Compare the reference text with the converted transcript
    feedback = util.calculate_pronunciation_accuracy(
        reference_text=input_text,
        output_text=arabic_transcript,
        script_choice=script_choice,
    )
    correct_phoneme, user_phoneme, pronunciation_match, pronunciation_score = feedback

    # Echo the raw recognition result to stdout
    print(f"ASR: {latin_transcript}")

    return (
        arabic_transcript,
        latin_transcript,
        correct_phoneme,
        user_phoneme,
        pronunciation_match,
        pronunciation_score,
    )