|
import torch |
|
from datasets import load_dataset |
|
from transformers import Pipeline, SpeechT5Processor, SpeechT5HifiGan |
|
|
|
|
|
class TTSPipeline(Pipeline):
    """Text-to-speech pipeline for SpeechT5 models.

    Wraps a SpeechT5 text-to-speech model together with its processor
    (tokenizer/feature extractor) and a HiFi-GAN vocoder. Both ``vocoder``
    and ``processor`` may be given as instantiated objects or as Hub model
    identifier strings, in which case they are loaded via ``from_pretrained``.

    Usage: ``pipe(text)`` or ``pipe(text, speaker_embeddings=emb)`` where
    ``emb`` is a ``(1, embedding_dim)`` x-vector tensor. When no speaker
    embedding is supplied, a default x-vector from the
    ``Matthijs/cmu-arctic-xvectors`` dataset is downloaded once and cached.
    """

    def __init__(self, *args, vocoder=None, processor=None, **kwargs):
        super().__init__(*args, **kwargs)

        if vocoder is None:
            raise ValueError("Must pass a vocoder to the TTSPipeline.")

        if processor is None:
            raise ValueError("Must pass a processor to the TTSPipeline.")

        # Accept either ready objects or Hub identifiers for convenience.
        if isinstance(vocoder, str):
            vocoder = SpeechT5HifiGan.from_pretrained(vocoder)

        if isinstance(processor, str):
            processor = SpeechT5Processor.from_pretrained(processor)

        self.processor = processor
        self.vocoder = vocoder
        # Lazily-populated cache for the default speaker embedding so the
        # x-vector dataset is downloaded at most once per pipeline instance.
        self._default_speaker_embeddings = None

    def _sanitize_parameters(self, **pipeline_parameters):
        """Route call-time kwargs to the pipeline stages.

        BUG FIX: the original returned three empty dicts, which silently
        dropped a ``speaker_embeddings=`` argument passed to ``__call__`` —
        it never reached ``preprocess``. Forward it explicitly here, per the
        transformers custom-pipeline contract (preprocess, forward,
        postprocess parameter dicts).
        """
        preprocess_params = {}
        if "speaker_embeddings" in pipeline_parameters:
            preprocess_params["speaker_embeddings"] = pipeline_parameters["speaker_embeddings"]
        return preprocess_params, {}, {}

    def _get_default_speaker_embeddings(self):
        """Return the cached default x-vector, downloading it on first use.

        The original re-loaded the whole validation split on every call;
        this hoists that invariant work behind a one-time cache.
        """
        if self._default_speaker_embeddings is None:
            embeddings_dataset = load_dataset("Matthijs/cmu-arctic-xvectors", split="validation")
            # Index 7305 is a fixed reference speaker; unsqueeze adds the
            # batch dimension expected by generate_speech.
            self._default_speaker_embeddings = torch.tensor(
                embeddings_dataset[7305]["xvector"]
            ).unsqueeze(0)
        return self._default_speaker_embeddings

    def preprocess(self, text, speaker_embeddings=None):
        """Tokenize ``text`` and resolve the speaker embedding to use."""
        inputs = self.processor(text=text, return_tensors='pt')

        if speaker_embeddings is None:
            speaker_embeddings = self._get_default_speaker_embeddings()

        return {'inputs': inputs, 'speaker_embeddings': speaker_embeddings}

    def _forward(self, model_inputs):
        """Generate the waveform from tokenized text and a speaker embedding."""
        inputs = model_inputs['inputs']
        speaker_embeddings = model_inputs['speaker_embeddings']

        # Inference only — no gradients needed.
        with torch.no_grad():
            speech = self.model.generate_speech(inputs['input_ids'], speaker_embeddings, vocoder=self.vocoder)

        return speech

    def postprocess(self, speech):
        """Return the generated waveform tensor unchanged."""
        return speech
|
|