Mahiruoshi's picture
Upload 73 files
4de73fc
raw
history blame
5.09 kB
import argparse
from text import text_to_sequence
import numpy as np
from scipy.io import wavfile
import torch
import json
import commons
import utils
import sys
import pathlib
from flask import Flask, request
import threading
import onnxruntime as ort
import time
from pydub import AudioSegment
import io
import os
from transformers import AutoTokenizer, AutoModel
import tkinter as tk
from tkinter import scrolledtext
from scipy.io.wavfile import write
def get_args(argv=None):
    """Parse the command-line options for the chat + TTS front end.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) lets
            argparse read ``sys.argv[1:]`` exactly as before; passing a list
            allows parsing without touching the real command line.

    Returns:
        argparse.Namespace with the attributes ``onnx_model`` (path to the
        ONNX model), ``cfg`` (path to the JSON config), ``outdir`` (output
        folder), ``audio`` (base path of the audio file to overwrite) and
        ``ChatGLM`` (path to the ChatGLM model folder).
    """
    parser = argparse.ArgumentParser(description='inference')
    parser.add_argument('--onnx_model', default='./moe/model.onnx')
    parser.add_argument('--cfg', default="./moe/config_v.json")
    parser.add_argument('--outdir', default="./moe",
                        help='output folder')
    parser.add_argument(
        '--audio',
        type=str,
        help='你要替换的音频文件的,假设这些音频文件为temp1、temp2、temp3......',
        default='D:/app_develop/live2d_whole/2010002/sounds/temp.wav')
    parser.add_argument('--ChatGLM', default="./moe",
                        help='https://github.com/THUDM/ChatGLM-6B')
    return parser.parse_args(argv)
def to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Convert a tensor to a NumPy array, whatever its device or grad state.

    The tensor is always detached and moved to the CPU first. The previous
    version skipped ``.cpu()`` when ``requires_grad`` was False, which raised
    for CUDA tensors.

    Args:
        tensor: Any torch tensor.

    Returns:
        The tensor's data as a NumPy array.
    """
    return tensor.detach().cpu().numpy()
def get_symbols_from_json(path):
    """Load the symbol table from a JSON config file.

    Args:
        path: Path to a JSON file with a top-level ``"symbols"`` entry.

    Returns:
        The value stored under ``"symbols"``.

    Raises:
        FileNotFoundError: if *path* is not an existing regular file.
        KeyError: if the JSON document has no ``"symbols"`` key.
    """
    # Explicit check rather than ``assert``, which is removed under ``python -O``.
    if not os.path.isfile(path):
        raise FileNotFoundError(f"config file not found: {path}")
    # The symbol table may contain non-ASCII characters; decode as UTF-8
    # instead of relying on the platform's default encoding.
    with open(path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    return data['symbols']
# Module-level setup: read CLI options, load the symbol table and the
# hyper-parameters from the same config file, then open the ONNX session.
args = get_args()
symbols = get_symbols_from_json(args.cfg)
# symbol -> index lookup; for a repeated symbol the last index wins.
phone_dict = dict(zip(symbols, range(len(symbols))))
hps = utils.get_hparams_from_file(args.cfg)
ort_sess = ort.InferenceSession(args.onnx_model)
def is_japanese(string):
    """Return True if *string* contains at least one kana character.

    Checks code points U+3040..U+30FF inclusive (the Hiragana and Katakana
    blocks). The previous exclusive bounds missed U+30FF (ヿ). Kanji are not
    detected because they share code points with Chinese characters, so
    kana-free Japanese text is treated as Chinese.

    Args:
        string: Text to inspect.

    Returns:
        True if any character falls in the kana range, otherwise False
        (including for an empty string).
    """
    return any(0x3040 <= ord(ch) <= 0x30FF for ch in string)
def _copy_targets(audio_path, copies=19):
    """Return the distinct export paths: 'temp' -> 'temp1' ... 'temp<copies>'.

    Only the file-name part of *audio_path* is rewritten, so a directory whose
    name contains "temp" is left alone. Duplicates (e.g. when the file name has
    no "temp" at all) are collapsed while keeping order.
    """
    name = os.path.basename(audio_path)
    folder = audio_path[:len(audio_path) - len(name)]
    targets = [folder + name.replace('temp', 'temp' + str(i))
               for i in range(1, copies + 1)]
    return list(dict.fromkeys(targets))


def infer(text, sid=7):
    """Synthesize *text* with the ONNX model and export 44.1 kHz copies.

    Steps:
      1. Wrap the text in ``[JA]..[JA]`` if it contains kana, else ``[ZH]..[ZH]``.
      2. Convert it to a symbol sequence with ``hps.data.text_cleaners``.
      3. Run it through ``ort_sess`` and peak-normalize the output.
      4. Write an int16 WAV to ``args.audio + '.wav'`` at
         ``hps.data.sampling_rate``.
      5. Resample it to 44.1 kHz with ffmpeg into the first of the
         ``temp1``..``temp19`` paths derived from ``args.audio``, then copy
         that file to the remaining paths.

    Args:
        text: Text to speak.
        sid: Speaker id fed to the model (default 7, the previously
            hard-coded value).

    Returns:
        The language-tagged text that was synthesized.
    """
    import shutil
    import subprocess

    text = f"[JA]{text}[JA]" if is_japanese(text) else f"[ZH]{text}[ZH]"
    seq = text_to_sequence(text, cleaner_names=hps.data.text_cleaners)
    if hps.data.add_blank:
        seq = commons.intersperse(seq, 0)  # interleave token id 0 between symbols

    x = np.array([seq], dtype=np.int64)
    ort_inputs = {
        'input': x,
        'input_lengths': np.array([x.shape[1]], dtype=np.int64),
        # Shape (1, 3). NOTE(review): presumably noise_scale, noise_scale_w,
        # length_scale as in VITS exports; confirm against the model.
        'scales': np.array([[0.667, 0.7, 1]], dtype=np.float32),
        'sid': np.array([sid], dtype=np.int64),
    }

    t1 = time.time()
    audio = np.squeeze(ort_sess.run(None, ort_inputs))
    # Peak-normalize to 60% of int16 full scale; the 0.01 floor avoids
    # dividing by a near-zero peak on silent output.
    audio *= 32767.0 / max(0.01, np.max(np.abs(audio))) * 0.6
    audio = np.clip(audio, -32767.0, 32767.0)

    src = args.audio + '.wav'
    wavfile.write(src, hps.data.sampling_rate, audio.astype(np.int16))

    targets = _copy_targets(args.audio)
    # Argument list without a shell: paths with spaces or shell
    # metacharacters are passed through safely.
    cmd = ['ffmpeg', '-y', '-i', src, '-ar', '44100', targets[0]]
    try:
        result = subprocess.run(cmd, check=False)
    except FileNotFoundError:
        # Best-effort export, like the previous os.system call: report, don't crash.
        print("ffmpeg not found; 44.1 kHz export skipped")
    else:
        if result.returncode == 0:
            # Every target holds the same resampled audio: convert once, then copy.
            for dst in targets[1:]:
                shutil.copyfile(targets[0], dst)
        else:
            print("ffmpeg failed with exit code", result.returncode)
    t2 = time.time()
    print("推理耗时:", (t2 - t1), "s")
    return text
# Load the ChatGLM tokenizer and model from the folder given by --ChatGLM.
tokenizer = AutoTokenizer.from_pretrained(args.ChatGLM, trust_remote_code=True)
# fp16 weights, then ChatGLM's quantize(4) (INT4), then moved to the GPU.
# The original note says this configuration targets an 8 GB GPU.
model = (
    AutoModel.from_pretrained(args.ChatGLM, trust_remote_code=True)
    .half()
    .quantize(4)
    .cuda()
)
# Conversation history passed to model.chat().
history = []
def send_message():
    """Handle a click on the Send button.

    Behavior depends on the text in the input box:
      * blank input is ignored;
      * ``clear`` resets the ChatGLM conversation history and empties the
        input box;
      * anything else is sent to ChatGLM with the running history. The reply
        is synthesized with ``infer`` and both lines are appended to the chat
        box.
    """
    global history
    message = input_box.get("1.0", "end-1c")  # text typed by the user
    t1 = time.time()
    if not message.strip():
        return
    if message.strip() == 'clear':
        # Previously this path could fall through to the chat-box update with
        # ``text`` undefined.
        history = []
        input_box.delete("1.0", tk.END)
        return

    # model.chat() returns the reply plus the updated history. Keep the
    # history so the next turn has the conversation context (it used to be
    # discarded).
    response, history = model.chat(tokenizer, message, history)
    # Drop spaces and turn newlines into '.' before synthesis.
    response = response.replace(" ", '').replace("\n", '.')
    text = infer(response)
    # infer() returns the text wrapped in [JA]/[ZH] tags; strip them for display.
    text = text.replace('[JA]', '').replace('[ZH]', '')

    chat_box.configure(state='normal')  # temporarily writable
    chat_box.insert(tk.END, "You: " + message + "\n")
    chat_box.insert(tk.END, "Tamao: " + text + "\n")
    chat_box.configure(state='disabled')  # back to read-only
    input_box.delete("1.0", tk.END)
    t2 = time.time()
    print("总共耗时:", (t2 - t1), "s")
# Build the chat window: a read-only transcript on top and an input row
# (text box + Send button) at the bottom.
root = tk.Tk()
root.title("Tamao")

# Transcript area. It starts read-only; send_message() re-enables it only
# while appending lines.
chat_box = scrolledtext.ScrolledText(root, width=50, height=10, state='disabled')
chat_box.pack(side=tk.TOP, fill=tk.BOTH, expand=True, padx=10, pady=10)

# Input row holding the text box and the Send button.
input_frame = tk.Frame(root)
input_frame.pack(side=tk.BOTTOM, fill=tk.X, padx=10, pady=10)

input_box = tk.Text(input_frame, width=50, height=3)
input_box.pack(side=tk.LEFT, fill=tk.X, expand=True, padx=10)

send_button = tk.Button(input_frame, text="Send", command=send_message)
send_button.pack(side=tk.RIGHT, padx=10)

# Hand control to the Tk event loop.
root.mainloop()