Update tts.py
Browse files
tts.py
CHANGED
@@ -1,62 +1,18 @@
|
|
1 |
-
import os
|
2 |
import torch
|
3 |
-
import
|
4 |
-
import gradio as gr
|
5 |
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
if "vits" not in sys.path:
|
10 |
-
sys.path.append("vits")
|
11 |
-
|
12 |
-
from vits import commons, utils
|
13 |
-
from vits.models import SynthesizerTrn
|
14 |
-
|
15 |
-
class TextMapper(object):
|
16 |
-
def __init__(self, vocab_file):
|
17 |
-
self.symbols = [x.strip() for x in open(vocab_file, encoding="utf-8").readlines()]
|
18 |
-
self.SPACE_ID = self.symbols.index(" ")
|
19 |
-
self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
|
20 |
-
|
21 |
-
def text_to_sequence(self, text, cleaner_names):
|
22 |
-
sequence = [self._symbol_to_id[symbol] for symbol in text.strip()]
|
23 |
-
return sequence
|
24 |
-
|
25 |
-
def get_text(self, text, hps):
|
26 |
-
text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
|
27 |
-
if hps.data.add_blank:
|
28 |
-
text_norm = commons.intersperse(text_norm, 0)
|
29 |
-
return torch.LongTensor(text_norm)
|
30 |
-
|
31 |
-
def filter_oov(self, text, lang=None):
|
32 |
-
val_chars = self._symbol_to_id
|
33 |
-
return "".join(filter(lambda x: x in val_chars, text))
|
34 |
-
|
35 |
-
def synthesize(text, speed):
|
36 |
-
if speed is None:
|
37 |
-
speed = 1.0
|
38 |
-
|
39 |
-
lang_code = "fao"
|
40 |
-
|
41 |
-
vocab_file = hf_hub_download(repo_id="facebook/mms-tts", filename="vocab.txt", subfolder=f"models/{lang_code}")
|
42 |
-
config_file = hf_hub_download(repo_id="facebook/mms-tts", filename="config.json", subfolder=f"models/{lang_code}")
|
43 |
-
g_pth = hf_hub_download(repo_id="facebook/mms-tts", filename="G_100000.pth", subfolder=f"models/{lang_code}")
|
44 |
|
|
|
|
|
|
|
45 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
46 |
|
47 |
-
hps = utils.get_hparams_from_file(config_file)
|
48 |
-
text_mapper = TextMapper(vocab_file)
|
49 |
-
net_g = SynthesizerTrn(len(text_mapper.symbols), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model)
|
50 |
-
net_g.to(device)
|
51 |
-
net_g.eval()
|
52 |
-
utils.load_checkpoint(g_pth, net_g, None)
|
53 |
-
|
54 |
-
text = text.lower()
|
55 |
-
text = text_mapper.filter_oov(text)
|
56 |
-
stn_tst = text_mapper.get_text(text, hps)
|
57 |
with torch.no_grad():
|
58 |
-
|
59 |
-
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
|
60 |
-
hyp = net_g.infer(x_tst, x_tst_lengths, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0 / speed)[0][0, 0].cpu().float().numpy()
|
61 |
|
62 |
-
return
|
|
|
|
|
1 |
import torch
|
2 |
+
from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor
|
|
|
3 |
|
4 |
+
# Hub checkpoint id shared by the processor and the model, so both are
# always loaded from the same repository.
MODEL_ID = "microsoft/speecht5_tts"

# Both objects are created once at import time and reused by every call.
processor = SpeechT5Processor.from_pretrained(MODEL_ID)
model = SpeechT5ForTextToSpeech.from_pretrained(MODEL_ID)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
7 |
|
8 |
+
def synthesize_speech(text, speaker_embeddings=None):
    """Synthesize ``text`` with SpeechT5 and return the generated waveform.

    Args:
        text: The sentence to speak.
        speaker_embeddings: Optional speaker x-vector tensor of shape
            ``(1, model.config.speaker_embedding_dim)``. When omitted, an
            all-zero vector is used so the call still succeeds.
            NOTE(review): a zero vector is only a placeholder voice; pass a
            real x-vector for usable quality.

    Returns:
        The vocoder output moved to the CPU as a NumPy array.
        NOTE(review): the model card lists 16 kHz output; confirm before
        wiring this into an audio widget.
    """
    # Imported here because the vocoder class is not among the file's
    # top-level imports.
    from transformers import SpeechT5HifiGan

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Load the vocoder once and reuse it across calls.
    vocoder = getattr(synthesize_speech, "_vocoder", None)
    if vocoder is None:
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
        synthesize_speech._vocoder = vocoder
    vocoder.to(device)

    # The processor only tokenizes input passed as the ``text`` keyword.
    inputs = processor(text=text, return_tensors="pt").to(device)

    if speaker_embeddings is None:
        speaker_embeddings = torch.zeros(
            (1, model.config.speaker_embedding_dim)
        )
    speaker_embeddings = speaker_embeddings.to(device)

    with torch.no_grad():
        # generate_speech runs the decoder loop and, given a vocoder,
        # converts the predicted spectrogram into a waveform.
        waveform = model.generate_speech(
            inputs["input_ids"], speaker_embeddings, vocoder=vocoder
        )

    # The result is audio samples, not token ids, so it is returned
    # directly rather than passed through the tokenizer's decode().
    return waveform.cpu().numpy()
|