Futuresony committed
Commit e13bc5b · verified · 1 Parent(s): 5d490e8

Update tts.py

Files changed (1)
  1. tts.py +173 -0
tts.py CHANGED
@@ -1,3 +1,176 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+import os
+import re
+import tempfile
+import torch
+import sys
+import gradio as gr
+import numpy as np
+
+from huggingface_hub import hf_hub_download
+
+# Set up the TTS env: make the vendored VITS sources importable.
+if "vits" not in sys.path:
+    sys.path.append("vits")
+
+from vits import commons, utils
+from vits.models import SynthesizerTrn
+
+
+# data/tts/all_langs.tsv holds one "<iso> <name>" pair per line,
+# split on the first space.
+TTS_LANGUAGES = {}
+with open("data/tts/all_langs.tsv") as f:
+    for line in f:
+        iso, name = line.split(" ", 1)
+        TTS_LANGUAGES[iso.strip()] = name.strip()
+
+
+class TextMapper(object):
+    def __init__(self, vocab_file):
+        with open(vocab_file, encoding="utf-8") as f:
+            self.symbols = [x.replace("\n", "") for x in f.readlines()]
+        self.SPACE_ID = self.symbols.index(" ")
+        self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
+        self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
+
+    def text_to_sequence(self, text, cleaner_names):
+        """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
+
+        Args:
+            text: string to convert to a sequence
+            cleaner_names: names of the cleaner functions to run the text
+                through (unused here; the mapping is character-level)
+        Returns:
+            List of integers corresponding to the symbols in the text
+        """
+        sequence = []
+        clean_text = text.strip()
+        for symbol in clean_text:
+            symbol_id = self._symbol_to_id[symbol]
+            sequence += [symbol_id]
+        return sequence
+
+    def uromanize(self, text, uroman_pl):
+        # Romanize text by piping it through the uroman Perl script.
+        iso = "xxx"
+        with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
+            with open(tf.name, "w") as f:
+                f.write("\n".join([text]))
+            cmd = f"perl {uroman_pl} -l {iso} < {tf.name} > {tf2.name}"
+            os.system(cmd)
+            outtexts = []
+            with open(tf2.name) as f:
+                for line in f:
+                    line = re.sub(r"\s+", " ", line).strip()
+                    outtexts.append(line)
+            outtext = outtexts[0]
+        return outtext
+
+    def get_text(self, text, hps):
+        text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
+        if hps.data.add_blank:
+            text_norm = commons.intersperse(text_norm, 0)
+        text_norm = torch.LongTensor(text_norm)
+        return text_norm
+
+    def filter_oov(self, text, lang=None):
+        # Drop any character that is not in the model's vocabulary.
+        text = self.preprocess_char(text, lang=lang)
+        val_chars = self._symbol_to_id
+        txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
+        return txt_filt
+
+    def preprocess_char(self, text, lang=None):
+        """
+        Special treatment of characters in certain languages
+        """
+        if lang == "ron":
+            text = text.replace("ț", "ţ")
+            print(f"{lang} (ț -> ţ): {text}")
+        return text
+
+
+def synthesize(text=None, lang=None, speed=None):
+    if speed is None:
+        speed = 1.0
+
+    # The dropdown value looks like "eng (English)"; keep only the ISO code.
+    lang_code = lang.split()[0].strip()
+
+    # Fetch the per-language vocabulary, config, and generator checkpoint.
+    vocab_file = hf_hub_download(
+        repo_id="facebook/mms-tts",
+        filename="vocab.txt",
+        subfolder=f"models/{lang_code}",
+    )
+    config_file = hf_hub_download(
+        repo_id="facebook/mms-tts",
+        filename="config.json",
+        subfolder=f"models/{lang_code}",
+    )
+    g_pth = hf_hub_download(
+        repo_id="facebook/mms-tts",
+        filename="G_100000.pth",
+        subfolder=f"models/{lang_code}",
+    )
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda")
+    elif (
+        hasattr(torch.backends, "mps")
+        and torch.backends.mps.is_available()
+        and torch.backends.mps.is_built()
+    ):
+        device = torch.device("mps")
+    else:
+        device = torch.device("cpu")
+
+    print(f"Run inference with {device}")
+
+    assert os.path.isfile(config_file), f"{config_file} doesn't exist"
+    hps = utils.get_hparams_from_file(config_file)
+    text_mapper = TextMapper(vocab_file)
+    net_g = SynthesizerTrn(
+        len(text_mapper.symbols),
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        **hps.model,
+    )
+    net_g.to(device)
+    _ = net_g.eval()
+
+    _ = utils.load_checkpoint(g_pth, net_g, None)
+
+    # Models trained on romanized text need uroman at inference time too.
+    is_uroman = hps.data.training_files.split(".")[-1] == "uroman"
+
+    if is_uroman:
+        uroman_dir = "uroman"
+        assert os.path.exists(uroman_dir)
+        uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
+        text = text_mapper.uromanize(text, uroman_pl)
+
+    text = text.lower()
+    # Pass the bare ISO code so language-specific preprocessing
+    # (e.g. the "ron" special case) can match.
+    text = text_mapper.filter_oov(text, lang=lang_code)
+    stn_tst = text_mapper.get_text(text, hps)
+    with torch.no_grad():
+        x_tst = stn_tst.unsqueeze(0).to(device)
+        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
+        hyp = (
+            net_g.infer(
+                x_tst,
+                x_tst_lengths,
+                noise_scale=0.667,
+                noise_scale_w=0.8,
+                length_scale=1.0 / speed,
+            )[0][0, 0]
+            .cpu()
+            .float()
+            .numpy()
+        )
+
+    # Scale the float waveform in [-1, 1] to 16-bit PCM.
+    hyp = (hyp * 32768).astype(np.int16)
+    return (hps.data.sampling_rate, hyp), text
+
 
 TTS_EXAMPLES = [
     ["I am going to the store.", "eng (English)", 1.0],
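
Not shown in this diff is the Gradio wiring that consumes synthesize and TTS_EXAMPLES (gradio is imported but never used inside tts.py itself). Below is a minimal sketch of how the pieces might be hooked up; the component choices, labels, and slider range are assumptions, not part of this commit:

# Hypothetical wiring sketch -- not part of this commit.
import gradio as gr

from tts import TTS_EXAMPLES, TTS_LANGUAGES, synthesize

demo = gr.Interface(
    fn=synthesize,  # returns ((sampling_rate, int16 waveform), cleaned text)
    inputs=[
        gr.Text(label="Input text"),
        gr.Dropdown(
            # synthesize() recovers the ISO code by splitting on the first
            # space, so each choice must look like "eng (English)".
            choices=[f"{code} ({name})" for code, name in TTS_LANGUAGES.items()],
            label="Language",
        ),
        gr.Slider(minimum=0.1, maximum=4.0, value=1.0, step=0.1, label="Speed"),
    ],
    outputs=[
        gr.Audio(label="Generated audio"),
        gr.Text(label="Text after romanization and OOV filtering"),
    ],
    examples=TTS_EXAMPLES,
)

if __name__ == "__main__":
    demo.launch()

Each TTS_EXAMPLES row matches this input signature: the text, a "code (Name)" language string, and a speed multiplier that synthesize turns into length_scale = 1.0 / speed.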