# Hub page header (scrape residue): Oysiyl - "Update handler.py" - commit 093f5cc (verified)
import os
import re
from typing import Any, Dict, List, Text, Union

import numpy as np
import soundfile as sf
import torch
from speechbrain.pretrained import EncoderClassifier
from transformers import SpeechT5ForTextToSpeech
from transformers import SpeechT5HifiGan
from transformers import SpeechT5Processor
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

if torch.cuda.is_available():
    # Mixed precision on GPU: bfloat16 is available from compute capability 8.0
    # upwards (not only on 8.x), older GPUs use float16 instead.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
else:
    # Full precision on CPU.
    dtype = torch.float32
class EndpointHandler():
    """
    Text-to-speech inference handler built on SpeechT5.

    The handler loads a fine-tuned SpeechT5 acoustic model, the SpeechT5
    processor, a HiFi-GAN vocoder and a SpeechBrain x-vector speaker encoder.
    Incoming text is cleaned, transliterated from Cyrillic to Latin script and
    synthesised with either a default speaker (``speaker.wav``) or a speaker
    derived from a waveform supplied in the request.

    Relies on the module-level ``device`` and ``dtype`` settings.
    """

    # Punctuation, quote variants, the replacement character and the combining
    # stress accent that are stripped from the text before synthesis.
    _CHARS_TO_REMOVE = re.compile(r'[\-\…\–\"\“\%\‘\”\�\»\«\„\`\'́]')

    # Ordered (cyrillic, latin) transliteration pairs applied one after another.
    _CYRILLIC_TO_LATIN = (
        ('а', 'a'),
        ('б', 'b'),
        ('в', 'v'),
        ('г', 'h'),
        ('д', 'd'),
        ('е', 'e'),
        ('ж', 'zh'),
        ('з', 'z'),
        ('и', 'y'),
        ('й', 'j'),
        ('к', 'k'),
        ('л', 'l'),
        ('м', 'm'),
        ('н', 'n'),
        ('о', 'o'),
        ('п', 'p'),
        ('р', 'r'),
        ('с', 's'),
        ('т', 't'),
        ('у', 'u'),
        ('ф', 'f'),
        ('х', 'h'),
        ('ц', 'ts'),
        ('ч', 'ch'),
        ('ш', 'sh'),
        ('щ', 'sch'),
        ('ь', "'"),
        ('ю', 'ju'),
        ('я', 'ja'),
        ('є', 'je'),
        ('і', 'i'),
        ('ї', 'ji'),
        ('ґ', 'g'),
    )

    def __init__(self, path=""):
        """
        Load every model needed for inference and the default speaker embedding.

        :param path: directory of the deployed repository. It is only used to
            locate ``speaker.wav``; when the file is not found there, the
            relative path ``speaker.wav`` is used instead.
        """
        # Load all required models
        self.model_id = "Oysiyl/speecht5_tts_common_voice_uk"
        self.spk_model_name = "speechbrain/spkrec-xvect-voxceleb"
        self.model = SpeechT5ForTextToSpeech.from_pretrained(self.model_id, torch_dtype=dtype).to(device)
        self.processor = SpeechT5Processor.from_pretrained("microsoft/speecht5_tts")
        # NOTE(review): the vocoder is loaded without `torch_dtype`, unlike the
        # acoustic model above - confirm the mixed dtypes are intentional.
        self.vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan").to(device)
        self.speaker_model = EncoderClassifier.from_hparams(
            source=self.spk_model_name,
            run_opts={"device": device},
            savedir=os.path.join("/tmp", self.spk_model_name)
        )

        # Prefer the copy of speaker.wav inside `path`; keep the historical
        # working-directory lookup as a fallback.
        speaker_wav = os.path.join(path or "", "speaker.wav")
        if not os.path.isfile(speaker_wav):
            speaker_wav = "speaker.wav"
        # NOTE(review): the sample rate is discarded - presumably speaker.wav is
        # already at the rate the speaker encoder expects; confirm.
        waveform, _ = sf.read(speaker_wav)
        self.speaker_embeddings = self.create_speaker_embedding(waveform)

    @staticmethod
    def remove_special_characters_s(text: Text) -> Text:
        """
        Lower-case the text and strip or normalise special characters.

        :param text: raw input text.
        :return: lower-cased text without the characters in the removal set,
            with apostrophe look-alikes turned into ``'`` and ``ы`` into ``и``.
        """
        # Lower-case first so the replacements below also apply to upper-case
        # input such as 'Ы'.
        text = text.lower()
        # remove special characters
        text = EndpointHandler._CHARS_TO_REMOVE.sub('', text)
        # The ASCII apostrophe was stripped above; these look-alikes are
        # converted to it afterwards and therefore survive.
        text = text.replace("՚", "'")
        text = text.replace("’", "'")
        text = text.replace('ы', 'и')
        return text

    @staticmethod
    def cyrillic_to_latin(text: Text) -> Text:
        """
        Transliterate lower-case Cyrillic letters to Latin script.

        Characters without an entry in the table (including upper-case
        letters) are left unchanged.

        :param text: text to transliterate.
        :return: transliterated text.
        """
        for src, dst in EndpointHandler._CYRILLIC_TO_LATIN:
            text = text.replace(src, dst)
        return text

    def create_speaker_embedding(self, waveform: np.ndarray) -> torch.Tensor:
        """
        Compute a normalised speaker embedding from a waveform.

        :param waveform: audio samples, passed to ``torch.tensor``.
        :return: embedding tensor with a leading dimension of size 1, placed on
            ``device`` and cast to ``dtype``.
        """
        with torch.no_grad():
            embedding = self.speaker_model.encode_batch(torch.tensor(waveform))
            embedding = torch.nn.functional.normalize(embedding, dim=2)
        # Drop the singleton dimensions, then restore a single leading one.
        return embedding.squeeze().to(device=device, dtype=dtype).unsqueeze(0)

    def __call__(self, data: Dict[Text, Any]) -> Union[np.ndarray, Dict[Text, Text]]:
        """
        Synthesise speech for one request.

        :param data: A dictionary contains `inputs` (the text) and optionally
            `speaker_embeddings`, a waveform from which a speaker embedding is
            computed. Both keys are popped from the dictionary.
        :return: generated audio as a numpy array, or a dictionary with an
            `error` message when `inputs` is missing.
        """
        text = data.pop("inputs", None)
        # Check if text is not provided
        if text is None:
            return {"error": "Please provide a text."}

        # Despite the key name, the value is a raw waveform; fall back to the
        # default speaker when none is supplied.
        waveform = data.pop("speaker_embeddings", None)
        if waveform is None:
            speaker_embeddings = self.speaker_embeddings
        else:
            speaker_embeddings = self.create_speaker_embedding(waveform)

        # run inference pipeline
        text = self.remove_special_characters_s(text)
        text = self.cyrillic_to_latin(text)
        input_ids = self.processor(text=text, return_tensors="pt")['input_ids'].to(device)
        with torch.no_grad():
            spectrogram = self.model.generate_speech(input_ids, speaker_embeddings)
            speech = self.vocoder(spectrogram)
        # `.cpu()` is a no-op for tensors already on the CPU.
        return speech.cpu().numpy()