Spaces:
Running
Running
alessandro trinca tornidor
feat: refactor get_resulting_string to reduce its complexity, improve some return typing
0746ae5
import time | |
from string import punctuation | |
import epitran | |
import numpy as np | |
import torch | |
from . import WordMatching as wm | |
from . import WordMetrics | |
from . import app_logger | |
from .models import AIModels, ModelInterfaces as mi, RuleBasedModels, models as mo | |
def getTrainer(language: str):
    """Build a PronunciationTrainer for ``language``.

    The ASR model returned by ``mo.getASRModel`` is moved to the CPU, switched
    to evaluation mode and wrapped in ``AIModels.NeuralASR``. The phoneme
    converter depends on the language:
      * 'en' -> ``RuleBasedModels.EngPhonemConverter``
      * 'de' -> ``RuleBasedModels.EpitranPhonemConverter`` over epitran 'deu-Latn'

    Raises:
        ValueError: for any other language. This happens only after the ASR
            model has already been requested from ``mo.getASRModel``.
    """
    cpu = torch.device('cpu')
    asr_network, asr_decoder = mo.getASRModel(language)
    asr_network = asr_network.to(cpu)
    asr_network.eval()
    neural_asr = AIModels.NeuralASR(asr_network, asr_decoder)

    if language == 'en':
        converter = RuleBasedModels.EngPhonemConverter()
    elif language == 'de':
        converter = RuleBasedModels.EpitranPhonemConverter(epitran.Epitran('deu-Latn'))
    else:
        raise ValueError('Language not implemented')

    return PronunciationTrainer(neural_asr, converter)
class PronunciationTrainer:
    """Score a recorded utterance against the text the speaker was asked to read.

    Combines an ASR model (audio -> transcript and word time stamps) with a
    text-to-phoneme converter. It aligns the recognized words with the expected
    ones and derives per-word and overall accuracy scores.
    """

    current_transcript: str
    current_ipa: str
    current_recorded_audio: torch.Tensor
    current_recorded_transcript: str
    current_recorded_word_locations: list
    current_recorded_intonations: torch.Tensor
    # NOTE(review): mutable class attribute shared by every instance. Nothing in
    # this class writes to it (the methods use a local of the same name).
    current_words_pronunciation_accuracy = []
    # Anchor accuracies for categories 0, 1 and 2. A word gets the category whose
    # anchor is closest to its accuracy; ties go to the lower index (np.argmin).
    categories_thresholds = np.array([80, 60, 59])
    # Samples per second, used to convert between sample indices and seconds.
    sampling_rate = 16000

    def __init__(self, asr_model: mi.IASRModel, word_to_ipa_coverter: mi.ITextToPhonemModel) -> None:
        """Store the ASR model and the text-to-phoneme converter.

        The keyword ``word_to_ipa_coverter`` keeps its historical spelling, so
        callers that pass it by name keep working.
        """
        self.asr_model = asr_model
        self.ipa_converter = word_to_ipa_coverter

    def getTranscriptAndWordsLocations(self, audio_length_in_samples: int):
        """Return the ASR transcript and a padded (start, end) sample range per word.

        Reads the transcript and word locations from the ASR model, which must
        already have processed the audio. Each word's 'start_ts'/'end_ts' is
        widened by a 50 ms margin on both sides and clipped to
        ``[0, audio_length_in_samples - 1]``.
        """
        audio_transcript = self.asr_model.getTranscript()
        fade_duration_in_samples = 0.05 * self.sampling_rate
        word_locations_in_samples = [
            (int(np.maximum(0, word['start_ts'] - fade_duration_in_samples)),
             int(np.minimum(audio_length_in_samples - 1, word['end_ts'] + fade_duration_in_samples)))
            for word in self.asr_model.getWordLocations()
        ]
        return audio_transcript, word_locations_in_samples

    def getWordsRelativeIntonation(self, Audio: torch.Tensor, word_locations: list):
        """Return each word's RMS energy divided by the mean over all words.

        Args:
            Audio: recording, indexed as ``Audio[0][start:end]`` and
                ``Audio.shape[1]``. Presumably shaped (1, num_samples);
                TODO confirm.
            word_locations: (start, end) sample indices, one entry per word.

        Returns:
            Tensor of shape (len(word_locations), 1). Each window is widened by
            0.3 s on both sides and clipped to the audio. NOTE(review): the
            division produces NaN when every window is silent or the list is
            empty.
        """
        intonations = torch.zeros((len(word_locations), 1))
        intonation_fade_samples = 0.3 * self.sampling_rate
        app_logger.info(intonations.shape)
        last_sample = Audio.shape[1] - 1
        for word_idx, location in enumerate(word_locations):
            intonation_start = int(np.maximum(0, location[0] - intonation_fade_samples))
            intonation_end = int(np.minimum(last_sample, location[1] + intonation_fade_samples))
            # RMS of the (padded) word window.
            intonations[word_idx] = torch.sqrt(torch.mean(Audio[0][intonation_start:intonation_end] ** 2))
        intonations = intonations / torch.mean(intonations)
        return intonations

    ##################### ASR Functions ###########################

    def processAudioForGivenText(self, recordedAudio: torch.Tensor = None, real_text=None):
        """Transcribe ``recordedAudio`` and score it against ``real_text``.

        Args:
            recordedAudio: raw recording. Presumably shaped (1, num_samples);
                TODO confirm.
            real_text: expected sentence. When None, ``self.current_transcript``
                is used instead (see ``matchSampleAndRecordedWords``).

        Returns:
            dict with these keys:
              * 'recording_transcript'
              * 'real_and_transcribed_words'
              * 'recording_ipa'
              * 'start_time' and 'end_time' (space-separated seconds)
              * 'real_and_transcribed_words_ipa'
              * 'pronunciation_accuracy'
              * 'pronunciation_categories'
        """
        start = time.time()
        app_logger.info('starting getAudioTranscript...')
        recording_transcript, recording_ipa, word_locations = self.getAudioTranscript(recordedAudio)
        duration = time.time() - start
        app_logger.info(f'Time for NN to transcript audio: {duration}.')

        start = time.time()
        real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices = \
            self.matchSampleAndRecordedWords(real_text, recording_transcript)
        duration = time.time() - start
        app_logger.info(f'Time for matching transcripts: {duration}.')

        start_time, end_time = self.getWordLocationsFromRecordInSeconds(word_locations, mapped_words_indices)

        # NOTE: scoring uses the plain word pairs, not real_and_transcribed_words_ipa.
        pronunciation_accuracy, current_words_pronunciation_accuracy = self.getPronunciationAccuracy(
            real_and_transcribed_words)
        pronunciation_categories = self.getWordsPronunciationCategory(current_words_pronunciation_accuracy)

        return {'recording_transcript': recording_transcript,
                'real_and_transcribed_words': real_and_transcribed_words,
                'recording_ipa': recording_ipa, 'start_time': start_time, 'end_time': end_time,
                'real_and_transcribed_words_ipa': real_and_transcribed_words_ipa,
                'pronunciation_accuracy': pronunciation_accuracy,
                'pronunciation_categories': pronunciation_categories}

    def getAudioTranscript(self, recordedAudio: torch.Tensor = None):
        """Run preprocessing, ASR and phoneme conversion on a recording.

        Returns:
            (transcript, transcript converted to phonemes, per-word
            (start, end) sample ranges from getTranscriptAndWordsLocations).
        """
        app_logger.info('starting preprocessAudio...')
        current_recorded_audio = self.preprocessAudio(recordedAudio)
        app_logger.info('starting processAudio...')
        self.asr_model.processAudio(current_recorded_audio)
        app_logger.info('starting getTranscriptAndWordsLocations...')
        current_recorded_transcript, current_recorded_word_locations = self.getTranscriptAndWordsLocations(
            current_recorded_audio.shape[1])
        app_logger.info('starting convertToPhonem...')
        current_recorded_ipa = self.ipa_converter.convertToPhonem(current_recorded_transcript)
        app_logger.info('ok, return audio transcript!')
        return current_recorded_transcript, current_recorded_ipa, current_recorded_word_locations

    def getWordLocationsFromRecordInSeconds(self, word_locations, mapped_words_indices) -> tuple[str, str]:
        """Convert the aligned words' sample ranges to seconds.

        Returns:
            Two space-separated strings: the start times and the end times (in
            seconds) of ``word_locations[i]`` for every ``i`` in
            ``mapped_words_indices``, in that order.
        """
        start_time = []
        end_time = []
        for location_idx in mapped_words_indices:
            location = word_locations[location_idx]
            start_time.append(float(location[0]) / self.sampling_rate)
            end_time.append(float(location[1]) / self.sampling_rate)
        return ' '.join(map(str, start_time)), ' '.join(map(str, end_time))

    ##################### END ASR Functions ###########################

    ##################### Evaluation Functions ###########################

    def matchSampleAndRecordedWords(self, real_text, recorded_transcript):
        """Align the expected words with the recognized ones.

        Returns:
            A tuple of three items:
              * real_and_transcribed_words: one (expected, recognized) pair per
                expected word. '-' marks an expected word with no recognized
                counterpart.
              * real_and_transcribed_words_ipa: the same pairs converted with
                the phoneme converter.
              * mapped_words_indices: the indices returned by
                ``wm.get_best_mapped_words``.
        """
        words_estimated = recorded_transcript.split()
        if real_text is None:
            # NOTE(review): current_transcript is annotated as str, in which case
            # [0] is only its first character. Presumably it is assigned a
            # sequence elsewhere; verify against callers.
            words_real = self.current_transcript[0].split()
        else:
            words_real = real_text.split()

        mapped_words, mapped_words_indices = wm.get_best_mapped_words(words_estimated, words_real)

        real_and_transcribed_words = []
        real_and_transcribed_words_ipa = []
        for word_idx, real_word in enumerate(words_real):
            # Pad with '-' when the alignment yields fewer words than expected.
            # Do this without mutating the list from get_best_mapped_words.
            transcribed_word = mapped_words[word_idx] if word_idx < len(mapped_words) else '-'
            real_and_transcribed_words.append((real_word, transcribed_word))
            real_and_transcribed_words_ipa.append((self.ipa_converter.convertToPhonem(real_word),
                                                   self.ipa_converter.convertToPhonem(transcribed_word)))
        return real_and_transcribed_words, real_and_transcribed_words_ipa, mapped_words_indices

    def getPronunciationAccuracy(self, real_and_transcribed_words_ipa) -> tuple[float, list]:
        """Compute overall and per-word accuracy percentages.

        Args:
            real_and_transcribed_words_ipa: sequence of (expected, recognized)
                string pairs. Despite the name, processAudioForGivenText passes
                the plain word pairs.

        Returns:
            (overall accuracy rounded with np.round, list of per-word accuracies).
            Both sides are stripped of punctuation and lowercased. A word's score
            is 100 * (len(expected) - mismatches) / len(expected), where
            mismatches comes from WordMetrics.edit_distance_python. The score is
            negative when mismatches exceed the word length.

            Expected tokens made only of punctuation (e.g. a stand-alone "-")
            have nothing to pronounce. They get 100.0 and are left out of the
            overall score. If nothing at all is scorable, the overall score is
            100.0. Either case would otherwise divide by zero.
        """
        total_mismatches = 0.
        number_of_phonemes = 0.
        current_words_pronunciation_accuracy = []
        for pair in real_and_transcribed_words_ipa:
            real_without_punctuation = self.removePunctuation(pair[0]).lower()
            number_of_phonemes_in_word = len(real_without_punctuation)
            if number_of_phonemes_in_word == 0:
                current_words_pronunciation_accuracy.append(100.0)
                continue
            number_of_word_mismatches = WordMetrics.edit_distance_python(
                real_without_punctuation, self.removePunctuation(pair[1]).lower())
            total_mismatches += number_of_word_mismatches
            number_of_phonemes += number_of_phonemes_in_word
            current_words_pronunciation_accuracy.append(float(
                number_of_phonemes_in_word - number_of_word_mismatches) / number_of_phonemes_in_word * 100)

        if number_of_phonemes == 0:
            return np.round(100.0), current_words_pronunciation_accuracy
        percentage_of_correct_pronunciations = (
            number_of_phonemes - total_mismatches) / number_of_phonemes * 100
        return np.round(percentage_of_correct_pronunciations), current_words_pronunciation_accuracy

    def removePunctuation(self, word: str) -> str:
        """Return ``word`` without any character from ``string.punctuation``."""
        return ''.join(char for char in word if char not in punctuation)

    def getWordsPronunciationCategory(self, accuracies) -> list:
        """Map each accuracy to its category index (see categories_thresholds)."""
        return [self.getPronunciationCategoryFromAccuracy(accuracy) for accuracy in accuracies]

    def getPronunciationCategoryFromAccuracy(self, accuracy) -> int:
        """Return the index of the threshold closest to ``accuracy`` as a plain int."""
        return int(np.argmin(np.abs(self.categories_thresholds - accuracy)))

    def preprocessAudio(self, audio: torch.Tensor) -> torch.Tensor:
        """Remove the global mean and scale the audio to unit peak amplitude.

        A constant (e.g. silent) recording has a zero peak after mean removal.
        It is returned mean-removed but unscaled instead of becoming all-NaN.
        """
        audio = audio - torch.mean(audio)
        peak = torch.max(torch.abs(audio))
        if peak > 0:
            audio = audio / peak
        return audio