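"""VITS-based singing voice conversion (SVC) inference.

Builds the SynthesizerTrn generator and a test DataLoader, runs batched
inference, and writes the converted waveforms under ``args.output_dir``.
"""
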
import os

import torch
from torch.utils.data import DataLoader

from models.svc.base import SVCInference
from models.svc.vits.vits import SynthesizerTrn
from utils.io import save_audio
from utils.audio_slicer import is_silence


class VitsInference(SVCInference):
    def __init__(self, args=None, cfg=None, infer_type="from_dataset"):
        SVCInference.__init__(self, args, cfg)
        # _build_dataloader reads self.infer_type, so keep the argument
        # instead of silently discarding it.
        self.infer_type = infer_type

    def _build_model(self):
        # Spectrogram bins (n_fft // 2 + 1) and segment length in frames
        # (samples // hop_size) parameterize the VITS synthesizer.
        net_g = SynthesizerTrn(
            self.cfg.preprocess.n_fft // 2 + 1,
            self.cfg.preprocess.segment_size // self.cfg.preprocess.hop_size,
            self.cfg,
        )
        self.model = net_g
        return net_g

    def build_save_dir(self, dataset, speaker):
        # <output_dir>/svc_am_step-<step>_<mode>[/data_<dataset>][/spk_<speaker>]
        save_dir = os.path.join(
            self.args.output_dir,
            "svc_am_step-{}_{}".format(self.am_restore_step, self.args.mode),
        )
        if dataset is not None:
            save_dir = os.path.join(save_dir, "data_{}".format(dataset))
        if speaker != -1:
            save_dir = os.path.join(save_dir, "spk_{}".format(speaker))
        os.makedirs(save_dir, exist_ok=True)
        print("Saving to", save_dir)
        return save_dir

    def _build_dataloader(self):
        datasets, collate = self._build_test_dataset()
        self.test_dataset = datasets(self.args, self.cfg, self.infer_type)
        self.test_collate = collate(self.cfg)
        # Cap the batch size at the number of test utterances.
        self.test_batch_size = min(
            self.cfg.inference.batch_size, len(self.test_dataset.metadata)
        )
        test_dataloader = DataLoader(
            self.test_dataset,
            collate_fn=self.test_collate,
            num_workers=1,
            batch_size=self.test_batch_size,
            shuffle=False,
        )
        return test_dataloader

    @torch.inference_mode()
    def inference(self):
        res = []
        for i, batch in enumerate(self.test_dataloader):
            pred_audio_list = self._inference_each_batch(batch)
            for j, wav in enumerate(pred_audio_list):
                # The loader is not shuffled, so metadata order matches
                # the order of the generated waveforms.
                uid = self.test_dataset.metadata[i * self.test_batch_size + j]["Uid"]
                file = os.path.join(self.args.output_dir, f"{uid}.wav")
                print(f"Saving {file}")

                wav = wav.numpy(force=True)
                save_audio(
                    file,
                    wav,
                    self.cfg.preprocess.sample_rate,
                    add_silence=False,
                    turn_up=not is_silence(wav, self.cfg.preprocess.sample_rate),
                )
                res.append(file)
        return res

    def _inference_each_batch(self, batch_data, noise_scale=0.667):
        device = self.accelerator.device
        pred_res = []
        self.model.eval()
        with torch.no_grad():
            # Move every tensor in the batch to the model's device.
            for k, v in batch_data.items():
                batch_data[k] = v.to(device)

            # noise_scale controls the variance of the VITS prior; 0.667 is
            # the value commonly used for VITS inference.
            audios, f0 = self.model.infer(batch_data, noise_scale=noise_scale)

            pred_res.extend(audios)

        return pred_res
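

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative, not part of the module). It assumes the
# SVCInference base class wires up self.model, self.test_dataloader, and
# self.accelerator during __init__, and that the surrounding project exposes
# a config loader like utils.util.load_config; the argument fields below are
# only the ones this file itself reads, so the real CLI likely needs more.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    from utils.util import load_config  # assumed project helper

    args = Namespace(output_dir="results", mode="infer")
    cfg = load_config("egs/svc/VitsSVC/exp_config.json")  # hypothetical path
    inferencer = VitsInference(args, cfg, infer_type="from_dataset")
    wav_paths = inferencer.inference()  # returns the saved .wav file paths
    print(wav_paths)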