File size: 3,105 Bytes
22871e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import torch
from inference.tts.base_tts_infer import BaseTTSInfer
from modules.tts.syntaspeech.syntaspeech import SyntaSpeech
from utils.commons.ckpt_utils import load_ckpt
from utils.commons.hparams import hparams

from modules.tts.syntaspeech.syntactic_graph_buider import Sentence2GraphParser

class SyntaSpeechInfer(BaseTTSInfer):
    """Single-utterance SyntaSpeech inference on top of ``BaseTTSInfer``.

    Adds a ``Sentence2GraphParser`` whose language is selected from
    ``hparams['ds_name']``; the parsed syntactic graph is passed to the
    SyntaSpeech model together with the phoneme / word token tensors.
    """

    # Known dataset names -> language code handed to Sentence2GraphParser.
    _DS_NAME_TO_LANGUAGE = {
        'biaobei': 'zh',
        'ljspeech': 'en',
        'libritts': 'en',
    }

    def __init__(self, hparams, device=None):
        """Set up the base inference pipeline, then the syntactic graph parser.

        Args:
            hparams: hyper-parameter mapping; ``hparams['ds_name']`` selects
                the parser language via ``_DS_NAME_TO_LANGUAGE``.
            device: forwarded unchanged to ``BaseTTSInfer.__init__``.
        """
        super().__init__(hparams, device)
        language = self._DS_NAME_TO_LANGUAGE.get(hparams['ds_name'])
        # Unknown datasets get no parser (instead of leaving the attribute
        # undefined); input_to_batch then fails with a descriptive error.
        if language is None:
            self.syntactic_graph_builder = None
        else:
            self.syntactic_graph_builder = Sentence2GraphParser(language=language)

    def build_model(self):
        """Instantiate SyntaSpeech, restore its checkpoint and prepare it for inference.

        Returns:
            The model moved to ``self.device`` and switched to eval mode.
        """
        ph_dict_size = len(self.ph_encoder)
        word_dict_size = len(self.word_encoder)
        model = SyntaSpeech(ph_dict_size, word_dict_size, self.hparams)
        # Read work_dir from the instance config (the same one the model was
        # built with) rather than the module-level global hparams.
        load_ckpt(model, self.hparams['work_dir'], 'model')
        model.to(self.device)
        # Run under no_grad so no autograd graph is recorded for this step.
        with torch.no_grad():
            model.store_inverse_all()
        model.eval()
        return model

    def input_to_batch(self, item):
        """Convert one preprocessed item into a batch of size 1 on ``self.device``.

        Args:
            item: dict read via the keys 'item_name', 'text', 'ph', 'ph_token',
                'word_token', 'ph2word', 'spk_id', 'words' and 'ph_words'.
                'words' / 'ph_words' are space-separated strings; the token
                fields are assumed to be flat int sequences (TODO confirm
                against the preprocessing that produces ``item``).

        Returns:
            dict with the raw fields wrapped in 1-element lists, the token
            tensors with a leading batch dimension of 1, their lengths, the
            speaker ids, and the parsed graph / edge types (each in a list).

        Raises:
            ValueError: if no syntactic graph parser is configured for
                ``hparams['ds_name']``.
        """
        parser = getattr(self, 'syntactic_graph_builder', None)
        if parser is None:
            raise ValueError(
                f"SyntaSpeechInfer has no syntactic graph parser for "
                f"ds_name={self.hparams['ds_name']!r}; supported: "
                f"{sorted(self._DS_NAME_TO_LANGUAGE)}")

        def to_batch(seq):
            # 1-D sequence -> (1, T) long tensor on the inference device.
            return torch.LongTensor(seq)[None, :].to(self.device)

        txt_tokens = to_batch(item['ph_token'])
        txt_lengths = torch.LongTensor([txt_tokens.shape[1]]).to(self.device)
        word_tokens = to_batch(item['word_token'])
        word_lengths = torch.LongTensor([word_tokens.shape[1]]).to(self.device)
        ph2word = to_batch(item['ph2word'])
        # torch.LongTensor(<int>) would allocate an *uninitialized* tensor of
        # that size; a scalar id (or 1-element sequence) becomes shape (1,).
        spk_ids = torch.as_tensor(item['spk_id'], dtype=torch.long).reshape(-1).to(self.device)

        dgl_graph, etypes = parser.parse(
            item['text'],
            words=item['words'].split(" "),
            ph_words=item['ph_words'].split(" "),
        )
        dgl_graph = dgl_graph.to(self.device)
        etypes = etypes.to(self.device)

        return {
            'item_name': [item['item_name']],
            'text': [item['text']],
            'ph': [item['ph']],
            'txt_tokens': txt_tokens,
            'txt_lengths': txt_lengths,
            'word_tokens': word_tokens,
            'word_lengths': word_lengths,
            'ph2word': ph2word,
            'spk_ids': spk_ids,
            'graph_lst': [dgl_graph],
            'etypes_lst': [etypes],
        }

    def forward_model(self, inp):
        """Synthesize audio for one preprocessed item.

        Args:
            inp: item dict accepted by ``input_to_batch``.

        Returns:
            The vocoder output converted to a numpy array, first batch element.
        """
        sample = self.input_to_batch(inp)
        with torch.no_grad():
            output = self.model(
                sample['txt_tokens'],
                sample['word_tokens'],
                ph2word=sample['ph2word'],
                word_len=sample['word_lengths'].max(),
                infer=True,
                forward_post_glow=True,
                spk_id=sample.get('spk_ids'),
                graph_lst=sample['graph_lst'],
                etypes_lst=sample['etypes_lst']
            )
            mel_out = output['mel_out']
            wav_out = self.run_vocoder(mel_out)
        wav_out = wav_out.cpu().numpy()
        return wav_out[0]


if __name__ == "__main__":
    # Script entry point: delegate to example_run, which this class does not
    # define itself and therefore inherits.
    SyntaSpeechInfer.example_run()