import argparse
import json
import os
import time

import numpy as np
import onnxruntime as ort
import torch
from scipy.io import wavfile
from transformers import AutoTokenizer, AutoModel

import tkinter as tk
from tkinter import scrolledtext

import commons
import utils
from text import text_to_sequence
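
# Rough flow of this script (inferred from the code below): parse the CLI
# options, load the exported VITS ONNX model and its config, load a local
# ChatGLM checkpoint, and start a small Tkinter chat window.  Every ChatGLM
# reply is synthesized to speech and resampled to 44.1 kHz temp*.wav files
# with ffmpeg so an external (Live2D) client can play them.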


def get_args():
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--onnx_model', default='./moe/model.onnx')
    parser.add_argument('--cfg', default='./moe/config_v.json')
    parser.add_argument('--outdir', default='./moe',
                        help='output folder')
    parser.add_argument('--audio',
                        type=str,
                        help='The audio file to replace; the generated clips are assumed to be named temp1, temp2, temp3, ...',
                        default='D:/app_develop/live2d_whole/2010002/sounds/temp.wav')
    parser.add_argument('--ChatGLM', default='./moe',
                        help='https://github.com/THUDM/ChatGLM-6B')
    args = parser.parse_args()
    return args
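
# Example invocation (the script name here is illustrative; adjust the paths
# to your own model, config, and sound directory):
#   python chat_tts_gui.py --onnx_model ./moe/model.onnx --cfg ./moe/config_v.json \
#       --audio D:/app_develop/live2d_whole/2010002/sounds/temp.wav --ChatGLM ./moe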


def to_numpy(tensor: torch.Tensor):
    return tensor.detach().cpu().numpy() if tensor.requires_grad \
        else tensor.detach().numpy()


def get_symbols_from_json(path):
    assert os.path.isfile(path)
    # utf-8 is assumed so configs containing Japanese/Chinese symbols load
    # correctly regardless of the OS locale.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data['symbols']
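
# Module-level setup: the symbol table, hyperparameters, and ONNX session are
# created once at startup and shared by every call to infer().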
args = get_args()
symbols = get_symbols_from_json(args.cfg)
phone_dict = {
    symbol: i for i, symbol in enumerate(symbols)
}
hps = utils.get_hparams_from_file(args.cfg)
ort_sess = ort.InferenceSession(args.onnx_model)


def is_japanese(string):
    # Treat the text as Japanese if it contains any hiragana or katakana
    # (U+3040 - U+30FF); otherwise it is tagged as Chinese in infer().
    for ch in string:
        if ord(ch) > 0x3040 and ord(ch) < 0x30FF:
            return True
    return False
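
# infer() covers the whole TTS path: language tagging, phoneme conversion,
# ONNX synthesis, and an ffmpeg pass that resamples the result to 44.1 kHz
# and fans it out to temp1.wav ... temp19.wav next to the --audio path.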


def infer(text):
    sid = 7
    # Tag the text with the language markers expected by the text cleaners.
    text = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"

    seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners)
    if hps.data.add_blank:
        seq = commons.intersperse(seq, 0)
    with torch.no_grad():
        x = np.array([seq], dtype=np.int64)
        x_len = np.array([x.shape[1]], dtype=np.int64)
        sid = np.array([sid], dtype=np.int64)
        # Presumably (noise_scale, noise_scale_w, length_scale), matching the
        # usual VITS defaults.
        scales = np.array([0.667, 0.7, 1], dtype=np.float32)
        scales.resize(1, 3)
        ort_inputs = {
            'input': x,
            'input_lengths': x_len,
            'scales': scales,
            'sid': sid
        }
        t1 = time.time()
        audio = np.squeeze(ort_sess.run(None, ort_inputs))
        audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
        audio = np.clip(audio, -32767.0, 32767.0)
        wavfile.write(args.audio + '.wav', hps.data.sampling_rate, audio.astype(np.int16))
        # Resample to 44.1 kHz and copy to temp1.wav ... temp19.wav.
        i = 0
        while i < 19:
            i += 1
            cmd = 'ffmpeg -y -i ' + args.audio + '.wav' + ' -ar 44100 ' + args.audio.replace('temp', 'temp' + str(i))
            os.system(cmd)
        t2 = time.time()
        print("Inference time:", (t2 - t1), "s")
    return text
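
# Load the local ChatGLM-6B checkpoint.  .half().quantize(4) keeps the weights
# in 4-bit so the model fits on a consumer GPU; .cuda() requires a CUDA device.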
tokenizer = AutoTokenizer.from_pretrained(args.ChatGLM, trust_remote_code=True)
model = AutoModel.from_pretrained(args.ChatGLM, trust_remote_code=True).half().quantize(4).cuda()
history = []


def send_message():
    global history
    message = input_box.get("1.0", "end-1c")
    t1 = time.time()
    if message == 'clear':
        # Typing 'clear' resets the conversation history.
        history = []
    else:
        # new_history is not stored back into history, so every exchange is
        # answered without the context of previous turns.
        response, new_history = model.chat(tokenizer, message, history)
        response = response.replace(" ", '').replace("\n", '.')
        text = infer(response)
        text = text.replace('[JA]', '').replace('[ZH]', '')
        chat_box.configure(state='normal')
        chat_box.insert(tk.END, "You: " + message + "\n")
        chat_box.insert(tk.END, "Tamao: " + text + "\n")
        chat_box.configure(state='disabled')
        input_box.delete("1.0", tk.END)
        t2 = time.time()
        print("Total time:", (t2 - t1), "s")
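
# Minimal Tkinter UI: a read-only chat log on top and a text box plus Send
# button at the bottom.  send_message() runs on the Tk main thread, so the
# window is unresponsive while ChatGLM and the TTS model are generating.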
root = tk.Tk()
root.title("Tamao")

chat_box = scrolledtext.ScrolledText(root, width=50, height=10)
chat_box.configure(state='disabled')
chat_box.pack(side=tk.TOP, fill=tk.BOTH, padx=10, pady=10, expand=True)

input_frame = tk.Frame(root)
input_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=10)
input_box = tk.Text(input_frame, height=3, width=50)
input_box.pack(side=tk.LEFT, fill=tk.X, padx=10, expand=True)
send_button = tk.Button(input_frame, text="Send", command=send_message)
send_button.pack(side=tk.RIGHT, padx=10)

root.mainloop()