import argparse |
from text import text_to_sequence |
import numpy as np |
from scipy.io import wavfile |
import torch |
import json |
import commons |
import utils |
import sys |
import pathlib |
from flask import Flask, request |
import threading |
import onnxruntime as ort |
import time |
from pydub import AudioSegment |
import io |
import os |
from transformers import AutoTokenizer, AutoModel |
import tkinter as tk |
from tkinter import scrolledtext |
from scipy.io.wavfile import write |
def get_args(argv=None):
    """Parse the command-line options for this script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, exactly as before; passing an explicit list
            lets other code and tests build the options without touching
            ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``onnx_model``, ``cfg``,
        ``outdir``, ``audio`` and ``ChatGLM``.
    """
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--onnx_model', default='./moe/model.onnx')
    parser.add_argument('--cfg', default="./moe/config_v.json")
    parser.add_argument('--outdir', default="./moe",
                        help='output folder')
    # Help text (zh): the audio file to be replaced; the generated copies are
    # assumed to be named temp1, temp2, temp3, ...
    parser.add_argument('--audio',
                        type=str,
                        help='你要替换的音频文件的,假设这些音频文件为temp1、temp2、temp3......',
                        default='D:/app_develop/live2d_whole/2010002/sounds/temp.wav')
    # Directory holding the ChatGLM-6B weights (loaded with trust_remote_code).
    parser.add_argument('--ChatGLM', default="./moe",
                        help='https://github.com/THUDM/ChatGLM-6B')
    return parser.parse_args(argv)
def to_numpy(tensor: torch.Tensor):
    """Return *tensor* as a NumPy array, detached from the autograd graph.

    The tensor is always moved to the CPU first. ``Tensor.numpy()`` only
    works on CPU tensors, and the previous version skipped ``.cpu()`` whenever
    ``requires_grad`` was False, so CUDA tensors without grad raised.
    ``.cpu()`` on a tensor that is already on the CPU returns it unchanged,
    so CPU inputs behave as before.
    """
    return tensor.detach().cpu().numpy()
def get_symbols_from_json(path):
    """Return the ``symbols`` entry of a JSON config file.

    Args:
        path: Path to the JSON config (this script passes ``args.cfg``).

    Returns:
        The value stored under the top-level ``"symbols"`` key.

    Raises:
        FileNotFoundError: if *path* is not an existing regular file. An
            explicit raise is used instead of ``assert``, which ``python -O``
            strips.
        KeyError: if the JSON object has no ``"symbols"`` key.
    """
    if not os.path.isfile(path):
        raise FileNotFoundError(f"config file not found: {path}")
    # Decode as UTF-8 explicitly. Without an encoding, open() uses the
    # platform locale codec (e.g. a legacy code page on Windows), which
    # fails on non-ASCII symbol tables like those used for Chinese/Japanese.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data['symbols']
# One-time setup at import: parse the CLI, then load the symbol table,
# the hyper-parameters and the ONNX session from the configured paths.
args = get_args()
symbols = get_symbols_from_json(args.cfg)
# Symbol -> integer id, where the id is the symbol's position in the table
# (a later duplicate overwrites an earlier one).
phone_dict = dict(zip(symbols, range(len(symbols))))
hps = utils.get_hparams_from_file(args.cfg)
ort_sess = ort.InferenceSession(args.onnx_model)
def is_japanese(string):
    """Return True if *string* contains at least one kana character.

    Covers the Unicode Hiragana (U+3040-U+309F) and Katakana (U+30A0-U+30FF)
    blocks with inclusive bounds. The previous exclusive comparison missed
    U+30FF. CJK ideographs are outside these ranges, so kanji-only text is
    not detected as Japanese.
    """
    return any(0x3040 <= ord(ch) <= 0x30FF for ch in string)
def _build_ort_inputs(text, sid):
    """Build the onnxruntime input feed for one language-tagged utterance.

    Args:
        text: Text already wrapped in ``[JA]...[JA]`` or ``[ZH]...[ZH]`` tags.
        sid: Integer speaker id.

    Returns:
        Dict keyed by the ONNX input names; every array has a leading
        dimension of 1.
    """
    seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners)
    if hps.data.add_blank:
        # Presumably inserts id 0 between symbols, as add_blank requests.
        seq = commons.intersperse(seq, 0)
    x = np.array([seq], dtype=np.int64)
    return {
        'input': x,
        'input_lengths': np.array([x.shape[1]], dtype=np.int64),
        # Shape (1, 3), same values as before. Presumably the VITS
        # noise/noise_w/length scales; TODO confirm against the ONNX export.
        'scales': np.array([[0.667, 0.7, 1]], dtype=np.float32),
        'sid': np.array([sid], dtype=np.int64),
    }


def _export_variants(wav_path, audio_path, count=19):
    """Resample *wav_path* to 44.1 kHz and store it under numbered names.

    Each target is *audio_path* with ``'temp'`` in its file name replaced by
    ``'temp<i>'`` for i = 1..count (``temp.wav`` -> ``temp1.wav`` ...).
    Only the file name is rewritten. The previous whole-path ``str.replace``
    also rewrote any directory whose name contained "temp".

    ffmpeg runs once and the other targets are byte copies of its output,
    instead of re-encoding the same input for every target. Failures are
    printed and do not raise, so a missing ffmpeg does not abort the caller.
    """
    import shutil
    import subprocess

    name = os.path.basename(audio_path)
    folder = audio_path[:len(audio_path) - len(name)]  # keeps original separators
    targets = [folder + name.replace('temp', 'temp' + str(i))
               for i in range(1, count + 1)]
    if not targets:
        return
    first = targets[0]
    # Argument list, no shell: paths with spaces or shell metacharacters
    # are passed to ffmpeg verbatim.
    cmd = ['ffmpeg', '-y', '-i', wav_path, '-ar', '44100', first]
    try:
        result = subprocess.run(cmd)
    except OSError as err:  # e.g. ffmpeg is not on PATH
        print(f"ffmpeg could not be started: {err}")
        return
    if result.returncode != 0:
        print(f"ffmpeg exited with code {result.returncode}; skipping copies")
        return
    for target in targets[1:]:
        # A file name without 'temp' maps every target to the same path.
        if target == first:
            continue
        try:
            shutil.copyfile(first, target)
        except OSError as err:
            print(f"could not copy {first} to {target}: {err}")


def infer(text):
    """Synthesize *text* with the ONNX model and write the audio files.

    The text is tagged ``[JA]`` when it contains kana and ``[ZH]`` otherwise,
    then run through ``ort_sess`` with speaker id 7. The waveform is
    peak-normalised to 60 % of int16 full scale and written to
    ``args.audio + '.wav'``. That file is then resampled to 44.1 kHz into
    ``temp1`` ... ``temp19`` variants of ``args.audio``.

    Args:
        text: Plain text to speak.

    Returns:
        The language-tagged text that was synthesized.
    """
    sid = 7  # fixed speaker id for every utterance
    text = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
    ort_inputs = _build_ort_inputs(text, sid)

    t1 = time.time()
    audio = np.squeeze(ort_sess.run(None, ort_inputs))
    # Peak-normalise; the 0.01 floor avoids dividing by ~0 on silent output.
    audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
    audio = np.clip(audio, -32767.0, 32767.0)

    wav_path = args.audio + '.wav'
    wavfile.write(wav_path, hps.data.sampling_rate, audio.astype(np.int16))
    _export_variants(wav_path, args.audio)
    t2 = time.time()
    print("推理耗时:", (t2 - t1), "s")
    return text
# Load ChatGLM from the --ChatGLM directory (remote code trusted). The model
# is cast to fp16 with .half(), quantized with .quantize(4) (INT4 per the
# ChatGLM-6B README) and moved to the GPU with .cuda().
tokenizer = AutoTokenizer.from_pretrained(args.ChatGLM, trust_remote_code=True)
model = (
    AutoModel.from_pretrained(args.ChatGLM, trust_remote_code=True)
    .half()
    .quantize(4)
    .cuda()
)
# Conversation state passed to model.chat(); send_message resets it on 'clear'.
history = []
def send_message():
    """Handle the Send button: query ChatGLM, speak the reply, log both turns.

    Reads the input box. If the text is ``clear``, the conversation history
    is reset and the model is not queried. Otherwise the reply is voiced via
    ``infer`` and shown in the transcript. Blank input is ignored.

    Fixes over the previous version:
    * the history returned by ``model.chat`` is now kept, so later turns have
      context (it used to be discarded and every turn started from scratch);
    * ``clear`` no longer raises NameError on the undefined reply text, which
      used to leave the transcript editable and the input box uncleared.
    """
    global history
    message = input_box.get("1.0", "end-1c")
    if not message.strip():
        return
    t1 = time.time()
    transcript = ["You: " + message + "\n"]
    if message == 'clear':
        history = []
    else:
        response, history = model.chat(tokenizer, message, history)
        # Drop spaces and turn newlines into sentence breaks before TTS.
        response = response.replace(" ", '').replace("\n", '.')
        spoken = infer(response).replace('[JA]', '').replace('[ZH]', '')
        transcript.append("Tamao: " + spoken + "\n")
    # The transcript is kept read-only; unlock it only while appending.
    chat_box.configure(state='normal')
    for line in transcript:
        chat_box.insert(tk.END, line)
    chat_box.configure(state='disabled')
    input_box.delete("1.0", tk.END)
    t2 = time.time()
    print("总共耗时:", (t2 - t1), "s")
# Main window: a read-only scrolling transcript on top, and below it an
# input row holding a text box and a Send button wired to send_message.
# Pack order matters for the layout, so widgets are packed in this sequence.
root = tk.Tk()
root.title("Tamao")

chat_box = scrolledtext.ScrolledText(root, width=50, height=10, state='disabled')
chat_box.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=10, pady=10)

input_frame = tk.Frame(root)
input_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=10)

input_box = tk.Text(input_frame, height=3, width=50)
input_box.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=10)

send_button = tk.Button(input_frame, text="Send", command=send_message)
send_button.pack(side=tk.RIGHT, padx=10)

root.mainloop()