Remove intelligibility refinement
Browse files
- api.py +0 -26
- requirements.txt +1 -2
api.py
CHANGED
@@ -5,9 +5,7 @@ from urllib import request
|
|
5 |
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
8 |
-
import torchaudio
|
9 |
import progressbar
|
10 |
-
import ocotillo
|
11 |
|
12 |
from models.diffusion_decoder import DiffusionTts
|
13 |
from models.autoregressive import UnifiedVoice
|
@@ -262,27 +260,3 @@ class TextToSpeech:
|
|
262 |
if len(wav_candidates) > 1:
|
263 |
return wav_candidates
|
264 |
return wav_candidates[0]
|
265 |
-
|
266 |
-
def refine_for_intellibility(self, wav_candidates, corresponding_codes, output_path):
|
267 |
-
"""
|
268 |
-
Further refine the remaining candidates using a ASR model to pick out the ones that are the most understandable.
|
269 |
-
TODO: finish this function
|
270 |
-
:param wav_candidates:
|
271 |
-
:return:
|
272 |
-
"""
|
273 |
-
transcriber = ocotillo.Transcriber(on_cuda=True)
|
274 |
-
transcriptions = transcriber.transcribe_batch(torch.cat(wav_candidates, dim=0).squeeze(1), 24000)
|
275 |
-
best = 99999999
|
276 |
-
for i, transcription in enumerate(transcriptions):
|
277 |
-
dist = lev_distance(transcription, args.text.lower())
|
278 |
-
if dist < best:
|
279 |
-
best = dist
|
280 |
-
best_codes = corresponding_codes[i].unsqueeze(0)
|
281 |
-
best_wav = wav_candidates[i]
|
282 |
-
del transcriber
|
283 |
-
torchaudio.save(os.path.join(output_path, f'{voice}_poor.wav'), best_wav.squeeze(0).cpu(), 24000)
|
284 |
-
|
285 |
-
# Perform diffusion again with the high-quality diffuser.
|
286 |
-
mel = do_spectrogram_diffusion(diffusion, final_diffuser, best_codes, cond_diffusion, mean=False)
|
287 |
-
wav = vocoder.inference(mel)
|
288 |
-
torchaudio.save(os.path.join(args.output_path, f'{voice}.wav'), wav.squeeze(0).cpu(), 24000)
|
|
|
5 |
|
6 |
import torch
|
7 |
import torch.nn.functional as F
|
|
|
8 |
import progressbar
|
|
|
9 |
|
10 |
from models.diffusion_decoder import DiffusionTts
|
11 |
from models.autoregressive import UnifiedVoice
|
|
|
260 |
if len(wav_candidates) > 1:
|
261 |
return wav_candidates
|
262 |
return wav_candidates[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
@@ -7,5 +7,4 @@ inflect
|
|
7 |
progressbar
|
8 |
einops
|
9 |
unidecode
|
10 |
-
x-transformers
|
11 |
-
ocotillo
|
|
|
7 |
progressbar
|
8 |
einops
|
9 |
unidecode
|
10 |
+
x-transformers
|
|