import os
from string import punctuation

import librosa
import soundfile as sf
import torch
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

from tasks.tts.dataset_utils import FastSpeechWordDataset
from tasks.tts.tts_utils import load_data_preprocessor
from utils.ckpt_utils import load_ckpt
from utils.hparams import hparams as hp
from utils.hparams import set_hparams
from vocoders.hifigan import HifiGanGenerator


class BaseTTSInfer:
    def __init__(self, hparams, device=None):
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.hparams = hparams
        self.device = device
        self.data_dir = hparams['binary_data_dir']
        self.preprocessor, self.preprocess_args = load_data_preprocessor()
        self.ph_encoder, self.word_encoder = self.preprocessor.load_dict(self.data_dir)
        self.ds_cls = FastSpeechWordDataset
        # The acoustic model, vocoder, and ASR model are all kept in eval mode
        # on the same device.
        self.model = self.build_model()
        self.model.eval()
        self.model.to(self.device)
        self.vocoder = self.build_vocoder()
        self.vocoder.eval()
        self.vocoder.to(self.device)
        self.asr_processor, self.asr_model = self.build_asr()

    def build_model(self):
        raise NotImplementedError

    def forward_model(self, inp):
        raise NotImplementedError

    def build_asr(self):
        # wav2vec 2.0 CTC model fine-tuned on 960h of LibriSpeech, used to
        # transcribe reference audio.
        processor = Wav2Vec2Processor.from_pretrained("facebook/wav2vec2-base-960h")
        model = Wav2Vec2ForCTC.from_pretrained("facebook/wav2vec2-base-960h").to(self.device)
        return processor, model

    def build_vocoder(self):
        base_dir = self.hparams['vocoder_ckpt']
        config_path = f'{base_dir}/config.yaml'
        config = set_hparams(config_path, global_hparams=False)
        vocoder = HifiGanGenerator(config)
        load_ckpt(vocoder, base_dir, 'model_gen')
        return vocoder

    def run_vocoder(self, c):
        # The vocoder expects mel input as [B, C, T]; `c` arrives as [B, T, C].
        c = c.transpose(2, 1)
        y = self.vocoder(c)[:, 0]  # [B, 1, T_wav] -> [B, T_wav]
        return y

    def preprocess_input(self, inp):
        raise NotImplementedError

    def input_to_batch(self, item):
        raise NotImplementedError

    def postprocess_output(self, output):
        return output

    def infer_once(self, inp):
        # Full pipeline: preprocess -> model forward -> postprocess.
        inp = self.preprocess_input(inp)
        output = self.forward_model(inp)
        output = self.postprocess_output(output)
        return output

    @classmethod
    def example_run(cls, inp):
        from utils.audio import save_wav

        # The global `hp` must already be populated (e.g. via set_hparams())
        # before this is called.
        infer_ins = cls(hp)
        out = infer_ins.infer_once(inp)
        os.makedirs('infer_out', exist_ok=True)
        save_wav(out, f'infer_out/{hp["text"]}.wav', hp['audio_sample_rate'])
        print(f'Saved to infer_out/{hp["text"]}.wav.')

    def asr(self, file):
        sample_rate = self.hparams['audio_sample_rate']
        audio_input, source_sample_rate = sf.read(file)

        # Resample to the rate the rest of the pipeline expects if the source differs.
        if sample_rate is not None and source_sample_rate != sample_rate:
            audio_input = librosa.resample(audio_input, orig_sr=source_sample_rate, target_sr=sample_rate)

        # Normalize/pad the waveform and move it to the same device as the ASR model.
        input_values = self.asr_processor(
            audio_input, sampling_rate=sample_rate, return_tensors="pt"
        ).input_values.to(self.device)

        # Greedy CTC decoding: argmax over the logits, then collapse repeats/blanks to text.
        with torch.no_grad():
            logits = self.asr_model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)

        transcription = self.asr_processor.decode(predicted_ids[0])
        transcription = transcription.rstrip(punctuation)
        return audio_input, transcription
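
# --- Usage sketch (illustrative, not part of the original file) ---------------
# A concrete subclass is expected to fill in the four abstract hooks below.
# The class name `MyTTSInfer`, the hook bodies, and the input dict are all
# hypothetical placeholders, not this repo's actual API.
#
# class MyTTSInfer(BaseTTSInfer):
#     def build_model(self):
#         ...  # build the acoustic model and load its checkpoint from hparams
#
#     def preprocess_input(self, inp):
#         ...  # text -> phoneme/word ids via self.ph_encoder / self.word_encoder
#
#     def input_to_batch(self, item):
#         ...  # collate one preprocessed item into padded tensors on self.device
#
#     def forward_model(self, inp):
#         ...  # run the model, then self.run_vocoder(mel) -> waveform
#
# if __name__ == '__main__':
#     set_hparams()  # populate the global `hp` from the config file / CLI args
#     MyTTSInfer.example_run({'text': 'the quick brown fox'})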