File size: 4,807 Bytes
960cd20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import torch
from torch import no_grad, LongTensor
import utils
from utils import get_hparams_from_file, lang_dict
from vits import commons
from vits.mel_processing import spectrogram_torch
from vits.text import text_to_sequence
from vits.models import SynthesizerTrn


class VITS:
    """Wrapper around a VITS ``SynthesizerTrn`` for text-to-speech and voice conversion.

    The network is NOT built in ``__init__``: call :meth:`load_model` before
    :meth:`infer` or :meth:`voice_conversion`, and :meth:`release_model` to drop it.
    """

    def __init__(self, model_path, config, device="cpu", **kwargs):
        """Read hyper-parameters and remember where/how to load the network.

        Args:
            model_path: Checkpoint path later given to ``utils.load_checkpoint``
                by :meth:`load_model`.
            config: Either a path (``str``) passed to ``get_hparams_from_file``,
                or an already-loaded hyper-parameter object that is used as-is.
            device: Device the network is moved to in :meth:`load_model`.
            **kwargs: Accepted and ignored.
        """
        self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
        self.n_speakers = getattr(self.hps_ms.data, 'n_speakers', 0)
        self.n_symbols = len(getattr(self.hps_ms, 'symbols', []))
        self.speakers = getattr(self.hps_ms, 'speakers', ['0'])
        if not isinstance(self.speakers, list):
            # Non-list speakers are treated as a name -> id mapping; keep the
            # names ordered by ascending id.
            self.speakers = [name for name, _ in sorted(self.speakers.items(), key=lambda item: item[1])]
        # Prefer the data section's flag, fall back to the model section, and
        # write the result back so SynthesizerTrn(**hps.model) receives it.
        self.bert_embedding = getattr(self.hps_ms.data, 'bert_embedding',
                                      getattr(self.hps_ms.model, 'bert_embedding', False))
        self.hps_ms.model.bert_embedding = self.bert_embedding
        # Only the first cleaner name is kept (it is the lang_dict key below).
        # A missing, None or empty list falls back to None instead of raising.
        cleaners = getattr(self.hps_ms.data, 'text_cleaners', None) or [None]
        self.text_cleaners = cleaners[0]
        self.sampling_rate = self.hps_ms.data.sampling_rate
        self.device = device
        self.model_path = model_path

        # The network is loaded lazily: callers must invoke load_model()
        # before infer() / voice_conversion().

        self.lang = lang_dict.get(self.text_cleaners, ["unknown"])

    def load_model(self):
        """Build ``SynthesizerTrn``, load weights from ``model_path`` and move it to ``device``."""
        self.net_g_ms = SynthesizerTrn(
            self.n_symbols,
            self.hps_ms.data.filter_length // 2 + 1,
            self.hps_ms.train.segment_size // self.hps_ms.data.hop_length,
            n_speakers=self.n_speakers,
            **self.hps_ms.model)
        _ = self.net_g_ms.eval()
        utils.load_checkpoint(self.model_path, self.net_g_ms)
        self.net_g_ms.to(self.device)

    def release_model(self):
        """Drop the loaded network so its memory can be reclaimed.

        Safe to call when no model has been loaded, or when it was already released.
        """
        if hasattr(self, "net_g_ms"):
            del self.net_g_ms
        # Return cached CUDA blocks freed by the deletion (no-op without CUDA).
        torch.cuda.empty_cache()

    def get_cleaned_text(self, text, hps, cleaned=False):
        """Convert ``text`` into a ``LongTensor`` of symbol ids.

        Args:
            text: Input text.
            hps: Hyper-parameter object providing ``symbols`` and ``data``.
            cleaned: If True, skip the cleaners (``text`` is assumed to be
                already cleaned) and map symbols directly.

        Returns:
            ``(LongTensor, char_embed)`` when BERT embedding is enabled and
            ``cleaned`` is False. In that case no blank interspersing is applied.
            Otherwise a single ``LongTensor``, blank-interspersed when
            ``hps.data.add_blank`` is set.
        """
        if cleaned:
            text_norm = text_to_sequence(text, hps.symbols, [])
        else:
            if self.bert_embedding:
                text_norm, char_embed = text_to_sequence(text, hps.symbols, hps.data.text_cleaners,
                                                         bert_embedding=self.bert_embedding)
                text_norm = LongTensor(text_norm)
                return text_norm, char_embed
            else:
                text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        text_norm = LongTensor(text_norm)
        return text_norm

    def infer(self, text, id, noise, noisew, length, cleaned=False, **kwargs):
        """Synthesize ``text`` with speaker ``id``; requires :meth:`load_model`.

        Args:
            text: Input text (already cleaned when ``cleaned`` is True).
            id: Integer speaker id.
            noise: Passed to the network as ``noise_scale``.
            noisew: Passed to the network as ``noise_scale_w``.
            length: Passed to the network as ``length_scale``.
            cleaned: Skip the text cleaners (see :meth:`get_cleaned_text`).
            **kwargs: Accepted and ignored.

        Returns:
            ``net_g_ms.infer(...)[0][0, 0]`` as a float32 numpy array
            (presumably the mono waveform at ``sampling_rate`` — TODO confirm).

        Raises:
            ValueError: If ``cleaned`` is True on a BERT-embedding model. That
                path of get_cleaned_text produces no character embeddings.
        """
        if self.bert_embedding and cleaned:
            # Previously this failed later with an obscure tuple-unpacking error.
            raise ValueError("cleaned=True is not supported when bert_embedding is enabled: "
                             "character embeddings are only produced by the text cleaners")
        char_embeds = None
        if self.bert_embedding:
            stn_tst, char_embeds = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)
        else:
            stn_tst = self.get_cleaned_text(text, self.hps_ms, cleaned=cleaned)

        with no_grad():
            x_tst = stn_tst.unsqueeze(0).to(self.device)
            x_tst_lengths = LongTensor([stn_tst.size(0)]).to(self.device)
            x_tst_prosody = torch.FloatTensor(char_embeds).unsqueeze(0).to(
                self.device) if self.bert_embedding else None
            sid = LongTensor([id]).to(self.device)

            audio = self.net_g_ms.infer(x=x_tst,
                                        x_lengths=x_tst_lengths,
                                        sid=sid,
                                        noise_scale=noise,
                                        noise_scale_w=noisew,
                                        length_scale=length,
                                        bert=x_tst_prosody)[0][0, 0].data.float().cpu().numpy()

        torch.cuda.empty_cache()

        return audio

    def voice_conversion(self, audio_path, original_id, target_id):
        """Convert the voice in ``audio_path`` from ``original_id`` to ``target_id``.

        Requires :meth:`load_model`. The audio is read through
        ``utils.load_audio_to_torch`` at ``sampling_rate`` (presumably resampling;
        verify in utils). A linear spectrogram is then computed with the
        config's STFT settings and ``center=False``.

        Returns:
            ``net_g_ms.voice_conversion(...)[0][0, 0]`` as a float32 numpy array.
        """
        audio = utils.load_audio_to_torch(
            audio_path, self.sampling_rate)

        y = audio.unsqueeze(0)

        spec = spectrogram_torch(y, self.hps_ms.data.filter_length,
                                 self.sampling_rate, self.hps_ms.data.hop_length,
                                 self.hps_ms.data.win_length,
                                 center=False)
        spec_lengths = LongTensor([spec.size(-1)])
        sid_src = LongTensor([original_id])

        with no_grad():
            sid_tgt = LongTensor([target_id])
            audio = self.net_g_ms.voice_conversion(spec.to(self.device),
                                                   spec_lengths.to(self.device),
                                                   sid_src=sid_src.to(self.device),
                                                   sid_tgt=sid_tgt.to(self.device))[0][0, 0].data.cpu().float().numpy()

        torch.cuda.empty_cache()

        return audio