unijoh commited on
Commit
047c567
1 Parent(s): d501976

Update tts.py

Browse files
Files changed (1) hide show
  1. tts.py +8 -9
tts.py CHANGED
@@ -1,18 +1,17 @@
1
  import torch
2
- from transformers import SpeechT5Processor, SpeechT5ForTextToSpeech
3
- from datasets import load_dataset
4
- import soundfile as sf
5
 
6
  MODEL_ID = "microsoft/speecht5_tts"
7
  processor = SpeechT5Processor.from_pretrained(MODEL_ID)
8
  model = SpeechT5ForTextToSpeech.from_pretrained(MODEL_ID)
9
- vocoder = torch.hub.load("snakers4/silero-vad", "silero_vad", force_reload=True)
10
 
11
- def synthesize_speech(text_input):
12
- inputs = processor(text=text_input, return_tensors="pt")
 
 
 
13
 
14
  with torch.no_grad():
15
- speech = model.generate_speech(inputs["input_ids"], vocoder=vocoder)
16
 
17
- sf.write("output.wav", speech.numpy(), 16000)
18
- return "output.wav"
 
1
  import torch
2
+ from transformers import SpeechT5ForTextToSpeech, SpeechT5Processor
 
 
3
 
4
  MODEL_ID = "microsoft/speecht5_tts"
5
  processor = SpeechT5Processor.from_pretrained(MODEL_ID)
6
  model = SpeechT5ForTextToSpeech.from_pretrained(MODEL_ID)
 
7
 
8
+ def synthesize_speech(text):
9
+ inputs = processor(text, return_tensors="pt")
10
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
+ model.to(device)
12
+ inputs = inputs.to(device)
13
 
14
  with torch.no_grad():
15
+ speech = model.generate(**inputs)
16
 
17
+ return processor.decode(speech, skip_special_tokens=True)