import os import torch import argparse import numpy as np from scipy.io.wavfile import write import torchaudio import utils from speechsr24k.speechsr import SynthesizerTrn as SpeechSR24 from speechsr48k.speechsr import SynthesizerTrn as SpeechSR48 seed = 1111 torch.manual_seed(seed) torch.cuda.manual_seed(seed) np.random.seed(seed) def get_param_num(model): num_param = sum(param.numel() for param in model.parameters()) return num_param def SuperResoltuion(a, hierspeech): speechsr = hierspeech os.makedirs(a.output_dir, exist_ok=True) # Prompt load audio, sample_rate = torchaudio.load(a.input_speech) # support only single channel audio = audio[:1,:] # Resampling if sample_rate != 16000: audio = torchaudio.functional.resample(audio, sample_rate, 16000, resampling_method="kaiser_window") file_name = os.path.splitext(os.path.basename(a.input_speech))[0] ## SpeechSR (Optional) (16k Audio --> 24k or 48k Audio) with torch.no_grad(): converted_audio = speechsr(audio.unsqueeze(1).cuda()) converted_audio = converted_audio.squeeze() converted_audio = converted_audio / (torch.abs(converted_audio).max()) * 0.999 * 32767.0 converted_audio = converted_audio.cpu().numpy().astype('int16') file_name2 = "{}.wav".format(file_name) output_file = os.path.join(a.output_dir, file_name2) if a.output_sr == 48000: write(output_file, 48000, converted_audio) else: write(output_file, 24000, converted_audio) def model_load(a): if a.output_sr == 48000: speechsr = SpeechSR48(h_sr48.data.n_mel_channels, h_sr48.train.segment_size // h_sr48.data.hop_length, **h_sr48.model).cuda() utils.load_checkpoint(a.ckpt_sr48, speechsr, None) speechsr.eval() else: # 24000 Hz speechsr = SpeechSR24(h_sr.data.n_mel_channels, h_sr.train.segment_size // h_sr.data.hop_length, **h_sr.model).cuda() utils.load_checkpoint(a.ckpt_sr, speechsr, None) speechsr.eval() return speechsr def inference(a): speechsr = model_load(a) SuperResoltuion(a, speechsr) def main(): print('Initializing Inference Process..') parser = argparse.ArgumentParser() parser.add_argument('--input_speech', default='example/reference_4.wav') parser.add_argument('--output_dir', default='SR_results') parser.add_argument('--ckpt_sr', type=str, default='./speechsr24k/G_340000.pth') parser.add_argument('--ckpt_sr48', type=str, default='./speechsr48k/G_100000.pth') parser.add_argument('--output_sr', type=float, default=48000) a = parser.parse_args() global device, h_sr, h_sr48 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') h_sr = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr)[0], 'config.json') ) h_sr48 = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_sr48)[0], 'config.json') ) inference(a) if __name__ == '__main__': main()