# vits-simple-api-gsv / vits/hubert_vits.py
# Author: Artrajz — commit 960cd20 ("init"), 4.39 kB; Hugging Face file scan: no virus.
import librosa
import numpy as np
import torch
from torch import no_grad, LongTensor, inference_mode, FloatTensor
import utils
from utils import get_hparams_from_file, lang_dict
from vits import commons
from vits.text import text_to_sequence
from vits.models import SynthesizerTrn
class HuBert_VITS:
    """HuBERT-unit based VITS voice-conversion wrapper.

    Input audio is encoded with ``self.hubert.units`` and the resulting unit
    frames (optionally with the first channel replaced by a pitch track) are
    decoded by a ``SynthesizerTrn`` for a chosen speaker id.
    """

    def __init__(self, model_path, config, device=torch.device("cpu"), **kwargs):
        """Read hyper-parameters; the network itself is built in ``load_model``.

        Args:
            model_path: Checkpoint path later passed to ``utils.load_checkpoint``.
            config: Path to an hparams file (loaded with ``get_hparams_from_file``)
                or an already-loaded hparams object.
            device: Torch device the synthesizer is moved to in ``load_model``.
            **kwargs: Accepted for interface compatibility; unused here.
        """
        self.hps_ms = get_hparams_from_file(config) if isinstance(config, str) else config
        data = self.hps_ms.data
        self.n_speakers = getattr(data, 'n_speakers', 0)
        self.n_symbols = len(getattr(self.hps_ms, 'symbols', []))
        speakers = getattr(self.hps_ms, 'speakers', ['0'])
        if not isinstance(speakers, list):
            # Mapping of name -> index: order the names by their index value.
            speakers = [name for name, _ in sorted(speakers.items(), key=lambda kv: kv[1])]
        self.speakers = speakers
        self.use_f0 = getattr(data, 'use_f0', False)
        self.model_path = model_path
        self.device = device
        # `or` also covers an empty/None `text_cleaners`, which used to raise on `[0]`.
        key = (getattr(data, "text_cleaners", None) or ["none"])[0]
        self.lang = lang_dict.get(key, ["unknown"])

    def load_model(self, hubert):
        """Attach the HuBERT encoder, then build, load and place the synthesizer.

        Args:
            hubert: Encoder object exposing ``units(waveform_tensor)``; used by ``infer``.
        """
        self.hubert = hubert
        hps = self.hps_ms
        self.net_g_ms = SynthesizerTrn(
            self.n_symbols,
            hps.data.filter_length // 2 + 1,
            hps.train.segment_size // hps.data.hop_length,
            n_speakers=self.n_speakers,
            **hps.model)
        self.net_g_ms.eval()
        utils.load_checkpoint(self.model_path, self.net_g_ms)
        self.net_g_ms.to(self.device)

    def get_cleaned_text(self, text, hps, cleaned=False):
        """Convert ``text`` to a LongTensor of symbol ids.

        Args:
            text: Input text.
            hps: Hparams providing ``symbols``, ``data.text_cleaners`` and ``data.add_blank``.
            cleaned: If True, ``text`` is treated as already cleaned (no cleaners run).

        Returns:
            The id tensor (interspersed with 0 when ``hps.data.add_blank``). When a
            truthy ``bert_embedding`` attribute has been set on the instance and
            ``cleaned`` is False, returns ``(ids, char_embed)`` instead, without
            applying ``add_blank``.
        """
        # This class never assigns `bert_embedding`; reading it directly raised
        # AttributeError on every uncleaned call. A missing attribute means "disabled".
        bert_embedding = getattr(self, "bert_embedding", None)
        if cleaned:
            text_norm = text_to_sequence(text, hps.symbols, [])
        elif bert_embedding:
            text_norm, char_embed = text_to_sequence(text, hps.symbols, hps.data.text_cleaners,
                                                     bert_embedding=bert_embedding)
            return LongTensor(text_norm), char_embed
        else:
            text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
        if hps.data.add_blank:
            text_norm = commons.intersperse(text_norm, 0)
        return LongTensor(text_norm)

    def get_cleaner(self):
        """Return the first configured text-cleaner name, or None if none is configured."""
        return (getattr(self.hps_ms.data, 'text_cleaners', None) or [None])[0]

    def get_speakers(self, escape=False):
        """Return the speaker-name list (``escape`` is accepted but unused)."""
        return self.speakers

    @property
    def sampling_rate(self):
        """Output sampling rate taken from ``hps.data.sampling_rate``."""
        return self.hps_ms.data.sampling_rate

    def infer(self, audio_path, id, noise, noisew, length, f0_scale=1, **kwargs):
        """Convert the voice in ``audio_path`` to speaker ``id``.

        Requires ``load_model`` to have been called (it sets ``self.hubert`` and
        ``self.net_g_ms``).

        Args:
            audio_path: Audio source passed to ``librosa.load``.
            id: Target speaker index, passed to the synthesizer as ``sid``.
            noise: ``noise_scale`` for ``SynthesizerTrn.infer``.
            noisew: ``noise_scale_w`` for ``SynthesizerTrn.infer``.
            length: ``length_scale`` for ``SynthesizerTrn.infer``.
            f0_scale: Multiplier applied to the estimated f0 (only used when ``use_f0``).
            **kwargs: Ignored.

        Returns:
            The synthesized waveform as a float NumPy array on the CPU.
        """
        if self.use_f0:
            # Load at the model's rate for pitch tracking; HuBERT is fed 16 kHz audio.
            src_audio, src_sr = librosa.load(audio_path, sr=self.hps_ms.data.sampling_rate, mono=True)
            audio16000 = librosa.resample(src_audio, orig_sr=src_sr, target_sr=16000)
        else:
            audio16000, _ = librosa.load(audio_path, sr=16000, mono=True)

        with inference_mode():
            # The waveform is wrapped as (1, 1, samples); the squeezed result is
            # indexed as 2-D below, with axis 0 used as the frame (time) axis.
            units = self.hubert.units(FloatTensor(audio16000).unsqueeze(0).unsqueeze(0)).squeeze(0).numpy()

        if self.use_f0:
            f0 = librosa.pyin(src_audio,
                              sr=src_sr,
                              fmin=librosa.note_to_hz('C0'),
                              fmax=librosa.note_to_hz('C7'),
                              frame_length=1780)[0]
            # Linearly resample the f0 track onto the unit frame grid; NaN values
            # (pyin's default fill for unvoiced frames) become 0 via nan_to_num.
            n_frames = units.shape[0]
            positions = np.arange(0, len(f0) * n_frames, len(f0)) / n_frames
            f0 = np.nan_to_num(np.interp(positions, np.arange(len(f0)), f0)) * f0_scale
            # The first unit channel is overwritten with f0 / 10; the divisor
            # presumably matches the model's training setup — TODO confirm.
            units[:, 0] = f0 / 10

        sid = LongTensor([id]).to(self.device)
        with no_grad():
            x_tst = FloatTensor(units).unsqueeze(0).to(self.device)
            x_tst_lengths = LongTensor([units.shape[0]]).to(self.device)
            wav = self.net_g_ms.infer(x=x_tst,
                                      x_lengths=x_tst_lengths,
                                      sid=sid,
                                      noise_scale=noise,
                                      noise_scale_w=noisew,
                                      length_scale=length)[0][0, 0].data.float().cpu().numpy()
        torch.cuda.empty_cache()
        return wav