import logging logging.getLogger('numba').setLevel(logging.WARNING) logging.getLogger('matplotlib').setLevel(logging.WARNING) logging.getLogger('urllib3').setLevel(logging.WARNING) from text import text_to_sequence import numpy as np from scipy.io import wavfile import torch import json import commons import utils import sys import pathlib import onnxruntime as ort import gradio as gr import argparse import time import os from scipy.io.wavfile import write def is_japanese(string): for ch in string: if ord(ch) > 0x3040 and ord(ch) < 0x30FF: return True return False def is_english(string): import re pattern = re.compile('^[A-Za-z0-9.,:;!?()_*"\' ]+$') if pattern.fullmatch(string): return True else: return False def to_numpy(tensor: torch.Tensor): return tensor.detach().cpu().numpy() if tensor.requires_grad \ else tensor.detach().numpy() def get_symbols_from_json(path): assert os.path.isfile(path) with open(path, 'r') as f: data = json.load(f) return data['symbols'] def sle(language,text): text = text.replace('\n','。').replace(' ',',') if language == "中文": tts_input1 = "[ZH]" + text + "[ZH]" return tts_input1 elif language == "自动": tts_input1 = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]" return tts_input1 elif language == "日文": tts_input1 = "[JA]" + text + "[JA]" return tts_input1 elif language == "英文": tts_input1 = "[EN]" + text + "[EN]" return tts_input1 elif language == "手动": return text def get_text(text,hps_ms): text_norm = text_to_sequence(text,hps_ms.symbols,hps_ms.data.text_cleaners) if hps_ms.data.add_blank: text_norm = commons.intersperse(text_norm, 0) text_norm = torch.LongTensor(text_norm) return text_norm def create_tts_fn(ort_sess, speaker_id): def tts_fn(text , language, n_scale= 0.667,n_scale_w = 0.8, l_scale = 1 ): text =sle(language,text) seq = text_to_sequence(text,hps.symbols, cleaner_names=hps.data.text_cleaners) if hps.data.add_blank: seq = commons.intersperse(seq, 0) with torch.no_grad(): x = np.array([seq], dtype=np.int64) x_len = np.array([x.shape[1]], dtype=np.int64) sid = np.array([speaker_id], dtype=np.int64) scales = np.array([n_scale, n_scale_w, l_scale], dtype=np.float32) scales.resize(1, 3) ort_inputs = { 'input': x, 'input_lengths': x_len, 'scales': scales, 'sid': sid } t1 = time.time() audio = np.squeeze(ort_sess.run(None, ort_inputs)) audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6 audio = np.clip(audio, -32767.0, 32767.0) t2 = time.time() spending_time = "推理时间:"+str(t2-t1)+"s" print(spending_time) return (hps.data.sampling_rate, audio) return tts_fn if __name__ == '__main__': symbols = get_symbols_from_json('checkpoints/ShojoKageki/config.json') hps = utils.get_hparams_from_file('checkpoints/ShojoKageki/config.json') device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") models = [] schools = ["ShojoKageki-Nijigasaki","ShojoKageki","Nijigasaki"] lan = ["中文","日文","自动","手动"] with open("checkpoints/info.json", "r", encoding="utf-8") as f: models_info = json.load(f) for i in models_info: school = models_info[i] speakers = school["speakers"] checkpoint = school["checkpoint"] phone_dict = { symbol: i for i, symbol in enumerate(symbols) } ort_sess = ort.InferenceSession(checkpoint) content = [] for j in speakers: sid = int(speakers[j]['sid']) title = school example = speakers[j]['speech'] name = speakers[j]["name"] content.append((sid, name, title, example, create_tts_fn(ort_sess, sid))) models.append(content) with gr.Blocks() as app: gr.Markdown( "#