unijoh commited on
Commit
b20a6cf
1 Parent(s): 2859a48

Update tts.py

Browse files
Files changed (1) hide show
  1. tts.py +11 -55
tts.py CHANGED
@@ -1,62 +1,18 @@
1
- import os
2
  import torch
3
- import sys
4
- import gradio as gr
5
 
6
- from huggingface_hub import hf_hub_download
7
-
8
- # Setup TTS env
9
- if "vits" not in sys.path:
10
- sys.path.append("vits")
11
-
12
- from vits import commons, utils
13
- from vits.models import SynthesizerTrn
14
-
15
- class TextMapper(object):
16
- def __init__(self, vocab_file):
17
- self.symbols = [x.strip() for x in open(vocab_file, encoding="utf-8").readlines()]
18
- self.SPACE_ID = self.symbols.index(" ")
19
- self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
20
-
21
- def text_to_sequence(self, text, cleaner_names):
22
- sequence = [self._symbol_to_id[symbol] for symbol in text.strip()]
23
- return sequence
24
-
25
- def get_text(self, text, hps):
26
- text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
27
- if hps.data.add_blank:
28
- text_norm = commons.intersperse(text_norm, 0)
29
- return torch.LongTensor(text_norm)
30
-
31
- def filter_oov(self, text, lang=None):
32
- val_chars = self._symbol_to_id
33
- return "".join(filter(lambda x: x in val_chars, text))
34
-
35
- def synthesize(text, speed):
36
- if speed is None:
37
- speed = 1.0
38
-
39
- lang_code = "fao"
40
-
41
- vocab_file = hf_hub_download(repo_id="facebook/mms-tts", filename="vocab.txt", subfolder=f"models/{lang_code}")
42
- config_file = hf_hub_download(repo_id="facebook/mms-tts", filename="config.json", subfolder=f"models/{lang_code}")
43
- g_pth = hf_hub_download(repo_id="facebook/mms-tts", filename="G_100000.pth", subfolder=f"models/{lang_code}")
44
 
 
 
 
45
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
46
 
47
- hps = utils.get_hparams_from_file(config_file)
48
- text_mapper = TextMapper(vocab_file)
49
- net_g = SynthesizerTrn(len(text_mapper.symbols), hps.data.filter_length // 2 + 1, hps.train.segment_size // hps.data.hop_length, **hps.model)
50
- net_g.to(device)
51
- net_g.eval()
52
- utils.load_checkpoint(g_pth, net_g, None)
53
-
54
- text = text.lower()
55
- text = text_mapper.filter_oov(text)
56
- stn_tst = text_mapper.get_text(text, hps)
57
  with torch.no_grad():
58
- x_tst = stn_tst.unsqueeze(0).to(device)
59
- x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
60
- hyp = net_g.infer(x_tst, x_tst_lengths, noise_scale=0.667, noise_scale_w=0.8, length_scale=1.0 / speed)[0][0, 0].cpu().float().numpy()
61
 
62
- return gr.Audio.update(value=(hps.data.sampling_rate, hyp)), text
 
 
1
  import torch
2
+ from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor
 
3
 
4
+ MODEL_ID = "microsoft/speecht5_tts"
5
+ processor = SpeechT5Processor.from_pretrained(MODEL_ID)
6
+ model = SpeechT5ForTextToSpeech.from_pretrained(MODEL_ID)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ def synthesize_speech(text):
9
+ inputs = processor(text, return_tensors="pt")
10
+ # Set device
11
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
12
+ model.to(device)
13
+ inputs = inputs.to(device)
14
 
 
 
 
 
 
 
 
 
 
 
15
  with torch.no_grad():
16
+ speech = model.generate(**inputs)
 
 
17
 
18
+ return processor.decode(speech)