import os
import re
from typing import Dict, List, Text, Any

from transformers import SpeechT5ForTextToSpeech
from transformers import SpeechT5Processor
from transformers import SpeechT5HifiGan
from speechbrain.pretrained import EncoderClassifier
import soundfile as sf
import torch
import numpy as np

# set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if torch.cuda.is_available():
    # set mixed precision dtype
    # NOTE(review): only compute-capability major version 8 selects bfloat16;
    # any other GPU (older or newer) uses float16 -- confirm this is intended.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] == 8 else torch.float16
else:
    dtype = torch.float32

# Punctuation stripped before transliteration. The leading '-' is a literal
# dash inside the character class; U+0301 is the combining acute accent.
_CHARS_TO_REMOVE_RE = re.compile("[-…–\"“%‘”�»«„`'\u0301]")

# Single-pass transliteration table (Cyrillic letter -> Latin replacement).
_CYRILLIC_TO_LATIN = str.maketrans({
    'а': 'a', 'б': 'b', 'в': 'v', 'г': 'h', 'д': 'd', 'е': 'e',
    'ж': 'zh', 'з': 'z', 'и': 'y', 'й': 'j', 'к': 'k', 'л': 'l',
    'м': 'm', 'н': 'n', 'о': 'o', 'п': 'p', 'р': 'r', 'с': 's',
    'т': 't', 'у': 'u', 'ф': 'f', 'х': 'h', 'ц': 'ts', 'ч': 'ch',
    'ш': 'sh', 'щ': 'sch', 'ь': "'", 'ю': 'ju', 'я': 'ja', 'є': 'je',
    'і': 'i', 'ї': 'ji', 'ґ': 'g',
})


class EndpointHandler():
    """Text-to-speech request handler built on SpeechT5.

    Loads the acoustic model, text processor, vocoder and a speaker encoder
    once at construction, then turns each request's text into a waveform.
    """

    def __init__(self, path=""):
        """Load all required models and the default speaker embedding.

        :param path: Directory expected to contain ``speaker.wav``. When the
            file is not found there, the lookup falls back to ``speaker.wav``
            relative to the current working directory.
        """
        # Load all required models
        self.model_id = "Oysiyl/speecht5_tts_common_voice_uk"
        self.spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
        self.model = SpeechT5ForTextToSpeech.from_pretrained(
            self.model_id, torch_dtype=dtype
        ).to(device)
        self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
        self.speaker_model = EncoderClassifier.from_hparams(
            source=self.spk_model_name,
            run_opts={"device": device},
            savedir=os.path.join("/tmp", self.spk_model_name)
        )
        # Prefer the reference recording that ships next to the handler; keep
        # the working-directory lookup as a fallback for existing deployments.
        speaker_file = os.path.join(path, "speaker.wav") if path else "speaker.wav"
        if not os.path.isfile(speaker_file):
            speaker_file = "speaker.wav"
        waveform, samplerate = sf.read(speaker_file)
        self.speaker_embeddings = self.create_speaker_embedding(waveform)

    @staticmethod
    def remove_special_characters_s(text: Text) -> Text:
        """Strip unwanted punctuation and normalise the text to lower case.

        Removes the characters matched by ``_CHARS_TO_REMOVE_RE``, rewrites
        the apostrophe variants ``՚`` and ``’`` as ``'``, lower-cases the
        result and replaces ``ы`` with ``и``.

        :param text: Raw input text.
        :return: Cleaned, lower-cased text.
        """
        # remove special characters
        text = _CHARS_TO_REMOVE_RE.sub('', text)
        text = text.replace("՚", "'").replace("’", "'")
        # Lower-case before the letter substitution so that an upper-case
        # 'Ы' is replaced as well.
        text = text.lower()
        return text.replace('ы', 'и')

    @staticmethod
    def cyrillic_to_latin(text: Text) -> Text:
        """Transliterate lower-case Cyrillic letters to Latin.

        Characters without an entry in the table are left unchanged.

        :param text: Text to transliterate.
        :return: Transliterated text.
        """
        return text.translate(_CYRILLIC_TO_LATIN)

    def create_speaker_embedding(self, waveform: np.ndarray) -> torch.Tensor:
        """Compute a normalised speaker embedding for a waveform.

        :param waveform: Audio samples; anything ``torch.tensor`` accepts.
        :return: Tensor on ``device`` with dtype ``dtype`` and a leading
            dimension of size 1.
        """
        with torch.no_grad():
            embedding = self.speaker_model.encode_batch(torch.tensor(waveform))
            embedding = torch.nn.functional.normalize(embedding, dim=2)
        # Drop singleton dimensions, then cast/move in one step and restore a
        # leading dimension of size 1.
        return embedding.squeeze().to(device=device, dtype=dtype).unsqueeze(0)

    def __call__(self, data: Any) -> Any:
        """Synthesise speech for one request.

        :param data: A dictionary containing `inputs` (the text) and,
            optionally, `speaker_embeddings` (a waveform from which a speaker
            embedding is computed). Both keys are popped from `data`.
        :return: The generated audio as a numpy array, or a dictionary with
            an ``error`` message when `inputs` is missing.
        """
        text = data.pop("inputs", None)
        # Check if text is not provided
        if text is None:
            return {"error": "Please provide a text."}

        waveform = data.pop("speaker_embeddings", None)
        # Fall back to the default speaker when no waveform is supplied
        if waveform is None:
            speaker_embeddings = self.speaker_embeddings
        else:
            speaker_embeddings = self.create_speaker_embedding(waveform)

        # run inference pipeline
        text = self.remove_special_characters_s(text)
        text = self.cyrillic_to_latin(text)
        input_ids = self.processor(text=text, return_tensors="pt")['input_ids'].to(device)
        with torch.no_grad():
            spectrogram = self.model.generate_speech(input_ids, speaker_embeddings)
            speech = self.vocoder(spectrogram)

        # return output audio in numpy format (.cpu() is a no-op on CPU tensors)
        return speech.cpu().numpy()