Irpan
committed on
Commit
•
d29fa84
1
Parent(s):
619a599
app
Browse files
app.py
CHANGED
@@ -1,17 +1,7 @@
|
|
1 |
import gradio as gr
|
2 |
import util
|
3 |
import tts
|
4 |
-
|
5 |
-
# Functions
|
6 |
-
def check_pronunciation(input_text, script, user_audio):
|
7 |
-
# Placeholder logic for pronunciation checking
|
8 |
-
transcript_ugArab_box = "Automatic transcription of your audio (Arabic)..."
|
9 |
-
transcript_ugLatn_box = "Automatic transcription of your audio (Latin)..."
|
10 |
-
correct_pronunciation = "Correct pronunciation in IPA"
|
11 |
-
user_pronunciation = "User pronunciation in IPA"
|
12 |
-
pronunciation_match = "Matching segments in green, mismatched in red"
|
13 |
-
pronunciation_score = 85.7 # Replace with actual score calculation
|
14 |
-
return transcript_ugArab_box, transcript_ugLatn_box, correct_pronunciation, user_pronunciation, pronunciation_match, pronunciation_score
|
15 |
|
16 |
# Front-End
|
17 |
with gr.Blocks() as app:
|
@@ -101,13 +91,13 @@ with gr.Blocks() as app:
|
|
101 |
)
|
102 |
|
103 |
tts_btn.click(
|
104 |
-
tts.
|
105 |
inputs=[input_text, script_choice],
|
106 |
outputs=[example_audio]
|
107 |
)
|
108 |
|
109 |
check_btn.click(
|
110 |
-
check_pronunciation,
|
111 |
inputs=[input_text, script_choice, user_audio],
|
112 |
outputs=[transcript_ugArab_box, transcript_ugLatn_box, correct_pronunciation_box, user_pronunciation_box, match_box, score_box]
|
113 |
)
|
|
|
1 |
import gradio as gr
|
2 |
import util
|
3 |
import tts
|
4 |
+
import asr
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
# Front-End
|
7 |
with gr.Blocks() as app:
|
|
|
91 |
)
|
92 |
|
93 |
tts_btn.click(
|
94 |
+
tts.generate_audio,
|
95 |
inputs=[input_text, script_choice],
|
96 |
outputs=[example_audio]
|
97 |
)
|
98 |
|
99 |
check_btn.click(
|
100 |
+
asr.check_pronunciation,
|
101 |
inputs=[input_text, script_choice, user_audio],
|
102 |
outputs=[transcript_ugArab_box, transcript_ugLatn_box, correct_pronunciation_box, user_pronunciation_box, match_box, score_box]
|
103 |
)
|
asr.py
CHANGED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Wav2Vec2Processor, Wav2Vec2ForCTC
import torch
from umsc import UgMultiScriptConverter
import util

# ASR checkpoint: MMS wav2vec2 fine-tuned for Uyghur, Latin script output.
model_id = 'ixxan/wav2vec2-large-mms-1b-uyghur-latin'
asr_model = Wav2Vec2ForCTC.from_pretrained(model_id, target_lang="uig-script_latin")
asr_processor = Wav2Vec2Processor.from_pretrained(model_id)
asr_processor.tokenizer.set_target_lang("uig-script_latin")

# Prefer the GPU when one is present; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
asr_model = asr_model.to(device)
|
15 |
+
|
16 |
+
def asr(user_audio):
    """Transcribe *user_audio* into Uyghur Latin-script text with the MMS model."""
    # Bring the recording to the 16 kHz rate the wav2vec2 model expects.
    waveform, rate = util.load_and_resample_audio(user_audio, target_rate=16000)

    # Feature-extract, move everything onto the model's device, and run CTC.
    features = asr_processor(
        waveform.squeeze(), sampling_rate=rate, return_tensors="pt", padding=True
    )
    features = {name: tensor.to(device) for name, tensor in features.items()}
    with torch.no_grad():
        logits = asr_model(**features).logits

    # Greedy (argmax) decoding of the best token path.
    best_ids = torch.argmax(logits, dim=-1)
    return asr_processor.batch_decode(best_ids)[0]
|
28 |
+
|
29 |
+
|
30 |
+
def check_pronunciation(input_text, script, user_audio):
    """Compare the learner's recording against the reference *input_text*.

    Returns the ASR transcripts in Arabic and Latin scripts, the reference and
    user IPA strings, an HTML diff of the two, and a percentage score.
    """
    # The ASR model emits Latin script; derive the Arabic-script transcript from it.
    transcript_ugLatn_box = asr(user_audio)
    to_arabic = UgMultiScriptConverter('ULS', 'UAS')
    transcript_ugArab_box = to_arabic(transcript_ugLatn_box)

    # IPA conversion works on Arabic script, so normalise Latin input first.
    if script == 'Uyghur Latin':
        input_text = to_arabic(input_text)

    (correct_pronunciation, user_pronunciation,
     pronunciation_match, pronunciation_score) = util.calculate_pronunciation_accuracy(
        reference_text=input_text,
        output_text=transcript_ugArab_box,
        language_code='uig-Arab')

    return (transcript_ugArab_box, transcript_ugLatn_box, correct_pronunciation,
            user_pronunciation, pronunciation_match, pronunciation_score)
|
tts.py
CHANGED
@@ -2,20 +2,31 @@ from transformers import VitsModel, AutoTokenizer
|
|
2 |
import torch
|
3 |
from umsc import UgMultiScriptConverter
|
4 |
import scipy.io.wavfile
|
5 |
-
import os
|
6 |
|
7 |
-
|
8 |
-
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
ug_latn_to_arab = UgMultiScriptConverter('ULS', 'UAS')
|
13 |
-
if
|
14 |
input_text = ug_latn_to_arab(input_text)
|
15 |
|
16 |
-
|
|
|
|
|
|
|
17 |
with torch.no_grad():
|
18 |
-
tts_output = tts_model(**tts_inputs).waveform
|
19 |
|
20 |
# Save to a temporary file
|
21 |
output_path = "tts_output.wav"
|
|
|
2 |
import torch
|
3 |
from umsc import UgMultiScriptConverter
|
4 |
import scipy.io.wavfile
|
|
|
5 |
|
6 |
+
# TTS checkpoint: Meta MMS VITS voice for Uyghur, Arabic script input.
model_id = "facebook/mms-tts-uig-script_arabic"
tts_tokenizer = AutoTokenizer.from_pretrained(model_id)
tts_model = VitsModel.from_pretrained(model_id)

# Prefer the GPU when one is present; fall back to CPU otherwise.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tts_model = tts_model.to(device)
|
14 |
+
|
15 |
+
def generate_audio(input_text, script):
|
16 |
+
"""
|
17 |
+
Generate audio for the given input text and script
|
18 |
+
"""
|
19 |
+
# Convert text to Uyghur Arabic if needed
|
20 |
ug_latn_to_arab = UgMultiScriptConverter('ULS', 'UAS')
|
21 |
+
if script != "Uyghur Arabic":
|
22 |
input_text = ug_latn_to_arab(input_text)
|
23 |
|
24 |
+
# Tokenize and move inputs to the same device as the model
|
25 |
+
tts_inputs = tts_tokenizer(input_text, return_tensors="pt").to(device)
|
26 |
+
|
27 |
+
# Perform inference
|
28 |
with torch.no_grad():
|
29 |
+
tts_output = tts_model(**tts_inputs).waveform.cpu() # Move output back to CPU for saving
|
30 |
|
31 |
# Save to a temporary file
|
32 |
output_path = "tts_output.wav"
|
util.py
CHANGED
@@ -1,16 +1,21 @@
|
|
1 |
import random
|
2 |
from umsc import UgMultiScriptConverter
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# Lists of Uyghur short and long texts
|
5 |
short_texts = [
|
6 |
"سالام", "رەھمەت", "ياخشىمۇسىز"
|
7 |
]
|
8 |
long_texts = [
|
9 |
-
"مەكتەپكە بارغاندا تېخىمۇ بىلىملىك
|
10 |
"يېزا مەنزىرىسى ھەقىقەتەن گۈزەل.",
|
11 |
-
"
|
12 |
]
|
13 |
|
|
|
14 |
def generate_short_text(script_choice):
|
15 |
"""Generate a random Uyghur short text based on the type."""
|
16 |
ug_arab_to_latn = UgMultiScriptConverter('UAS', 'ULS')
|
@@ -27,4 +32,66 @@ def generate_long_text(script_choice):
|
|
27 |
text = random.choice(long_texts)
|
28 |
if script_choice == "Uyghur Latin":
|
29 |
return ug_arab_to_latn(text)
|
30 |
-
return text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import random
|
2 |
from umsc import UgMultiScriptConverter
|
3 |
+
import torchaudio
|
4 |
+
import string
|
5 |
+
import epitran
|
6 |
+
from difflib import SequenceMatcher
|
7 |
|
8 |
# Lists of Uyghur short and long texts
|
9 |
short_texts = [
|
10 |
"سالام", "رەھمەت", "ياخشىمۇسىز"
|
11 |
]
|
12 |
long_texts = [
|
13 |
+
"مەكتەپكە بارغاندا تېخىمۇ بىلىملىك بولۇمەن.",
|
14 |
"يېزا مەنزىرىسى ھەقىقەتەن گۈزەل.",
|
15 |
+
"بىزنىڭ ئۆيدەپ تۆت تەكچە تۆتىلىسى تەكتەكچە"
|
16 |
]
|
17 |
|
18 |
+
# Front-End Utils
|
19 |
def generate_short_text(script_choice):
|
20 |
"""Generate a random Uyghur short text based on the type."""
|
21 |
ug_arab_to_latn = UgMultiScriptConverter('UAS', 'ULS')
|
|
|
32 |
text = random.choice(long_texts)
|
33 |
if script_choice == "Uyghur Latin":
|
34 |
return ug_arab_to_latn(text)
|
35 |
+
return text
|
36 |
+
|
37 |
+
# ASR Utils
|
38 |
+
def load_and_resample_audio(file_path, target_rate):
    """Load an audio file and resample it to *target_rate* Hz.

    Returns ``(waveform, target_rate)``; resampling is skipped when the
    file's native sample rate already matches the requested one.
    """
    waveform, native_rate = torchaudio.load(file_path)
    if native_rate != target_rate:
        waveform = torchaudio.transforms.Resample(native_rate, target_rate)(waveform)
    return waveform, target_rate
|
45 |
+
|
46 |
+
def calculate_pronunciation_accuracy(reference_text, output_text, language_code='uig-Arab'):
    """
    Score how closely an ASR transcript matches a reference text, via IPA.

    Args:
        reference_text (str): Ground-truth Uyghur text (Arabic script).
        output_text (str): ASR output text (Arabic script).
        language_code (str): Epitran language code (default 'uig-Arab' for Uyghur).

    Returns:
        str: IPA transliteration of the reference text.
        str: IPA transliteration of the output text.
        str: HTML diff of the two (matches in green, differences in red).
        float: Pronunciation accuracy as a percentage.
    """
    ipa_converter = epitran.Epitran(language_code)

    # Punctuation would distort the alignment, so strip it before transliterating.
    reference_ipa = ipa_converter.transliterate(remove_punctuation(reference_text))
    output_ipa = ipa_converter.transliterate(remove_punctuation(output_text))

    # Character-level alignment of the two IPA strings.
    matcher = SequenceMatcher(None, reference_ipa, output_ipa)
    pronunciation_accuracy = matcher.ratio() * 100  # fraction matched -> percent

    # Colour each aligned segment: green for matches, red for any difference.
    pieces = []
    for opcode, i1, i2, j1, j2 in matcher.get_opcodes():
        if opcode == 'equal':
            pieces.append(f'<span style="color: green">{reference_ipa[i1:i2]}</span>')
        elif opcode == 'insert':
            # Sounds present in the user's output but absent from the reference.
            pieces.append(f'<span style="color: red">{output_ipa[j1:j2]}</span>')
        else:
            # 'replace' or 'delete': reference material the user got wrong or missed.
            pieces.append(f'<span style="color: red">{reference_ipa[i1:i2]}</span>')
    comparison_html = "".join(pieces)

    return reference_ipa, output_ipa, comparison_html, pronunciation_accuracy
|
94 |
+
|
95 |
+
def remove_punctuation(text):
    """Remove punctuation from *text*.

    Uses Unicode categories rather than ``string.punctuation`` so that
    non-ASCII marks common in Uyghur Arabic script (e.g. '،', '؟', '؛')
    are stripped along with the ASCII set; otherwise they would survive
    into the IPA alignment and skew the pronunciation score.
    """
    import unicodedata
    # Category codes starting with 'P' are the Unicode punctuation classes.
    return "".join(ch for ch in text if not unicodedata.category(ch).startswith("P"))
|