Voice-Recognition / infer_contrast.py
YuAnthony's picture
update files
dd65803
import argparse
import ast
import functools

import numpy as np
import torch

from utils.reader import load_audio
from utils.utility import add_arguments, print_arguments
# Command-line configuration, executed at import time.
# `add_arguments` is a project helper from utils.utility; presumably it
# registers each option as `--<name>` with the given type, default and help
# text -- TODO confirm against utils/utility.py.
# NOTE: description=__doc__ is None unless this module defines a docstring.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# Help text: "threshold for deciding whether it is the same person".
# The threshold is not applied anywhere in this section of the file.
add_arg('threshold', float, 0.71, '判断是否为同一个人的阈值')
# Help text: "shape of the input data". Kept as a string here and parsed
# into a tuple inside infer().
add_arg('input_shape', str, '(1, 257, 257)', '数据输入的形状')
# Help text: "path of the prediction model".
add_arg('model_path', str, 'models_large/resnet34.pth', '预测模型的路径')
# parse_known_args() is used instead of parse_args() so that unrecognized
# command-line arguments are ignored rather than treated as an error;
# element [0] of its result is the parsed namespace.
args =parser.parse_known_args()[0]
print_arguments(args)
# Prints whether CUDA is available; the model is loaded on the CPU regardless.
print(torch.cuda.is_available())
# `device` is not referenced again in this section of the file.
device = torch.device("cpu")
# Load the model.
# map_location="cpu" remaps the saved TorchScript module onto the CPU while
# loading, so no separate model.to(device) call is made afterwards.
model = torch.jit.load(args.model_path,map_location="cpu")
# Put the module into evaluation mode before it is used for inference.
model.eval()
# Run inference on an audio file
def infer(audio_path):
    """Compute the model's output features for a single audio file.

    Args:
        audio_path: Location of the audio file; passed unchanged to
            ``load_audio``.

    Returns:
        numpy.ndarray: The model output for a batch that contains only this
        one input, detached from autograd and located on the CPU.

    Raises:
        ValueError, SyntaxError: If ``args.input_shape`` is not a valid
            Python literal.
    """
    # Parse the shape string (e.g. '(1, 257, 257)') as a plain literal.
    # NOTE(security): this used to be a bare eval(), which would execute
    # arbitrary code supplied on the command line. literal_eval accepts the
    # same tuple literals but rejects anything that is not a literal.
    input_shape = ast.literal_eval(args.input_shape)
    # The third entry of the shape is handed to load_audio as `spec_len`.
    data = load_audio(audio_path, mode='infer', spec_len=input_shape[2])
    # Prepend one axis so the model receives a batch of size one
    # (assumes load_audio returns a single un-batched array -- TODO confirm).
    data = data[np.newaxis, :]
    data = torch.tensor(data, dtype=torch.float32)
    # Run the forward pass without building an autograd graph; gradients are
    # never needed here and tracking them only costs memory and time.
    with torch.no_grad():
        feature = model(data)
    return feature.detach().cpu().numpy()
def run(audio1,audio2):
    """Return the cosine similarity between the features of two audio files.

    Each file is passed through ``infer`` and the first row of the result is
    taken as that file's feature vector. The value returned is the dot
    product of the two vectors divided by the product of their Euclidean
    norms, i.e. the cosine of the angle between them.

    Args:
        audio1: Path of the first audio file.
        audio2: Path of the second audio file.

    Returns:
        The cosine similarity of the two feature vectors. No threshold is
        applied here; presumably the caller compares the value against
        ``args.threshold`` -- verify against the caller.
    """
    # Features for the two recordings being compared, in argument order.
    emb_a, emb_b = (infer(path)[0] for path in (audio1, audio2))
    # Cosine of the angle between the two feature vectors.
    norm_product = np.linalg.norm(emb_a) * np.linalg.norm(emb_b)
    return np.dot(emb_a, emb_b) / norm_product