Spaces:
Running
on
T4
Running
on
T4
File size: 5,108 Bytes
4300fed 870f6b3 7af4264 4300fed 7af4264 4300fed 7af4264 4300fed 7af4264 4300fed 7af4264 4300fed 7af4264 4300fed 7af4264 4300fed 7af4264 4300fed 7af4264 4300fed 205284d 3e15339 7af4264 3e15339 205284d 4300fed 7af4264 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import os
import re
import json
import torch
import librosa
import soundfile
import torchaudio
import numpy as np
import torch.nn as nn
from tqdm import tqdm
import torch
from . import utils
from . import commons
from .models import SynthesizerTrn
from .split_utils import split_sentence
from .mel_processing import spectrogram_torch, spectrogram_torch_conv
from .download_utils import load_or_download_config, load_or_download_model
class TTS(nn.Module):
    """Inference wrapper that turns text into a waveform with ``SynthesizerTrn``.

    The constructor obtains the hyper-parameters through
    ``load_or_download_config`` and the weights through
    ``load_or_download_model``, builds the synthesizer on the selected device
    and puts it in eval mode. ``tts_to_file`` then synthesizes text sentence by
    sentence and either returns the audio or writes it to disk.
    """

    # A lower-case ASCII letter immediately followed by an upper-case one;
    # a space is inserted between the two before EN / ZH_MIX_EN synthesis.
    _CASE_BOUNDARY = re.compile(r'([a-z])([A-Z])')

    def __init__(self,
                 language,
                 device='auto',
                 use_hf=True,
                 config_path=None,
                 ckpt_path=None):
        """Build the synthesizer and load its checkpoint.

        Args:
            language: Language key forwarded to the config/model loaders. Only
                the part before the first ``_`` is kept as ``self.language``.
            device: Torch device string, or ``'auto'`` to pick ``'mps'`` or
                ``'cuda'`` when available and fall back to ``'cpu'``.
            use_hf: Forwarded unchanged to both loader helpers.
            config_path: Forwarded to ``load_or_download_config``.
            ckpt_path: Forwarded to ``load_or_download_model``.

        Raises:
            AssertionError: If a CUDA device is requested but CUDA is not
                available.
        """
        super().__init__()
        if device == 'auto':
            device = 'cpu'
            if torch.cuda.is_available():
                device = 'cuda'
            # Torch builds without the MPS backend have no ``torch.backends.mps``.
            mps_backend = getattr(torch.backends, 'mps', None)
            if mps_backend is not None and mps_backend.is_available():
                device = 'mps'
        if 'cuda' in device:
            assert torch.cuda.is_available(), 'CUDA device requested but CUDA is not available'

        hps = load_or_download_config(language, use_hf=use_hf, config_path=config_path)

        num_languages = hps.num_languages
        num_tones = hps.num_tones
        symbols = hps.symbols

        model = SynthesizerTrn(
            len(symbols),
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=hps.data.n_speakers,
            num_tones=num_tones,
            num_languages=num_languages,
            **hps.model,
        ).to(device)

        model.eval()
        self.model = model
        self.symbol_to_id = {s: i for i, s in enumerate(symbols)}
        self.hps = hps
        self.device = device

        # load state_dict
        checkpoint_dict = load_or_download_model(language, device, use_hf=use_hf, ckpt_path=ckpt_path)
        self.model.load_state_dict(checkpoint_dict['model'], strict=True)

        language = language.split('_')[0]
        self.language = 'ZH_MIX_EN' if language == 'ZH' else language  # we support a ZH_MIX_EN model

    @staticmethod
    def audio_numpy_concat(segment_data_list, sr, speed=1.):
        """Join audio segments, appending a short silence after each one.

        Args:
            segment_data_list: Iterable of array-likes; each is flattened to 1-D.
            sr: Sampling rate in Hz.
            speed: Speed factor; the silence is ``int((sr * 0.05) / speed)``
                samples long.

        Returns:
            A 1-D ``float32`` array; empty when no segments are given.
        """
        segments = [np.asarray(segment, dtype=np.float32).reshape(-1)
                    for segment in segment_data_list]
        if not segments:
            return np.zeros(0, dtype=np.float32)

        # A non-positive sample count means "no silence", as with ``[0] * n``.
        pause = np.zeros(max(0, int((sr * 0.05) / speed)), dtype=np.float32)
        pieces = []
        for segment in segments:
            pieces.append(segment)
            pieces.append(pause)
        return np.concatenate(pieces)

    @staticmethod
    def split_sentences_into_pieces(text, language, quiet=False):
        """Split ``text`` with ``split_sentence`` and optionally print the pieces.

        Args:
            text: Input text.
            language: Passed to ``split_sentence`` as ``language_str``.
            quiet: When true, nothing is printed.

        Returns:
            The value returned by ``split_sentence``.
        """
        texts = split_sentence(text, language_str=language)
        if not quiet:
            print(" > Text split to sentences.")
            print('\n'.join(texts))
            print(" > ===========================")
        return texts

    def _infer_segment(self, t, language, device, speaker_id, sdp_ratio, noise_scale, noise_scale_w, speed):
        """Synthesize one text piece and return its waveform as a NumPy array."""
        bert, ja_bert, phones, tones, lang_ids = utils.get_text_for_tts_infer(
            t, language, self.hps, device, self.symbol_to_id)
        with torch.no_grad():
            # Every input gets a leading batch dimension of size 1.
            x_tst = phones.to(device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([phones.size(0)]).to(device)
            speakers = torch.LongTensor([speaker_id]).to(device)
            outputs = self.model.infer(
                x_tst,
                x_tst_lengths,
                speakers,
                tones.to(device).unsqueeze(0),
                lang_ids.to(device).unsqueeze(0),
                bert.to(device).unsqueeze(0),
                ja_bert.to(device).unsqueeze(0),
                sdp_ratio=sdp_ratio,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=1. / speed,
            )
            # The device tensors go out of scope on return, so no manual ``del``.
            return outputs[0][0, 0].data.cpu().float().numpy()

    def tts_to_file(self, text, speaker_id, output_path=None, sdp_ratio=0.2, noise_scale=0.6, noise_scale_w=0.8, speed=1.0, pbar=None, format=None, position=None, quiet=False,):
        """Synthesize ``text`` and return the audio or write it to a file.

        Args:
            text: Text to synthesize; it is split into pieces first.
            speaker_id: Integer speaker index passed to the model.
            output_path: Destination for ``soundfile.write``. When ``None`` the
                audio is returned instead of written.
            sdp_ratio: Forwarded to ``model.infer``.
            noise_scale: Forwarded to ``model.infer``.
            noise_scale_w: Forwarded to ``model.infer``.
            speed: Speed factor; the model receives ``length_scale=1 / speed``.
            pbar: Optional callable wrapping the list of pieces (for example a
                progress-bar factory). Takes precedence over ``position`` and
                ``quiet``.
            format: Optional format passed to ``soundfile.write``.
            position: Optional ``tqdm`` row. Any value other than ``None``,
                including ``0``, requests a bar at that row.
            quiet: Suppress the sentence printout and the default progress bar.

        Returns:
            The concatenated ``float32`` waveform when ``output_path`` is
            ``None``, otherwise ``None``.
        """
        language = self.language
        device = self.device
        texts = self.split_sentences_into_pieces(text, language, quiet)

        if pbar:
            tx = pbar(texts)
        elif position is not None:
            tx = tqdm(texts, position=position)
        elif quiet:
            tx = texts
        else:
            tx = tqdm(texts)

        split_case_boundaries = language in ('EN', 'ZH_MIX_EN')
        audio_list = []
        for t in tx:
            if split_case_boundaries:
                t = self._CASE_BOUNDARY.sub(r'\1 \2', t)
            audio_list.append(self._infer_segment(
                t, language, device, speaker_id, sdp_ratio, noise_scale, noise_scale_w, speed))
        torch.cuda.empty_cache()

        audio = self.audio_numpy_concat(audio_list, sr=self.hps.data.sampling_rate, speed=speed)
        if output_path is None:
            return audio
        # ``format=None`` is soundfile's own default, so a falsy value maps to it.
        soundfile.write(output_path, audio, self.hps.data.sampling_rate, format=format or None)
|