"""Compare two audio recordings by the cosine similarity of their model features.

A TorchScript model is loaded on the CPU when this module is imported.
``run`` feeds two audio files through it and returns the cosine similarity
of the two resulting feature vectors.
"""
import argparse
import ast
import functools

import numpy as np
import torch

from utils.reader import load_audio
from utils.utility import add_arguments, print_arguments

parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
# Threshold for deciding whether two recordings come from the same person.
# NOTE: it is only declared here; run() returns the raw score and leaves the
# comparison against the threshold to the caller.
add_arg('threshold', float, 0.71, '判断是否为同一个人的阈值')
# Shape of the model input, given as a Python tuple literal.
add_arg('input_shape', str, '(1, 257, 257)', '数据输入的形状')
# Path of the model file used for prediction.
add_arg('model_path', str, 'models_large/resnet34.pth', '预测模型的路径')

# parse_known_args() ignores unrecognised options instead of exiting, so this
# module can be imported by a program that defines its own command line.
args = parser.parse_known_args()[0]
print_arguments(args)
print(torch.cuda.is_available())

# Inference is pinned to the CPU regardless of CUDA availability.
device = torch.device("cpu")

# Load the TorchScript model, remapping any stored tensors onto the CPU.
model = torch.jit.load(args.model_path, map_location=device)
model.eval()


def infer(audio_path):
    """Run the model on one audio file and return its output as a NumPy array.

    Args:
        audio_path: Path of the audio file, passed through to ``load_audio``.

    Returns:
        The model output converted to a ``numpy.ndarray``, with a leading
        axis of size 1 added to the loaded data before the forward pass.

    Raises:
        ValueError, SyntaxError: If ``args.input_shape`` is not a valid
            Python literal.
    """
    # SECURITY: the original used eval() on this command-line string, which
    # would execute arbitrary code. literal_eval accepts only literals such
    # as "(1, 257, 257)".
    input_shape = ast.literal_eval(args.input_shape)
    data = load_audio(audio_path, mode='infer', spec_len=input_shape[2])
    # Prepend a batch axis before handing the data to the model.
    data = data[np.newaxis, :]
    data = torch.tensor(data, dtype=torch.float32)
    # Run the prediction; gradients are never needed here, so skip autograd
    # bookkeeping.
    with torch.no_grad():
        feature = model(data)
    return feature.detach().cpu().numpy()


def run(audio1, audio2):
    """Return the cosine similarity between the features of two audio files.

    Args:
        audio1: Path of the first audio file.
        audio2: Path of the second audio file.

    Returns:
        The dot product of the two feature vectors divided by the product
        of their Euclidean norms.
    """
    # Features of the two recordings to compare (first item of each batch).
    feature1 = infer(audio1)[0]
    feature2 = infer(audio2)[0]
    # Cosine of the angle between the two feature vectors.
    dist = np.dot(feature1, feature2) / (np.linalg.norm(feature1) * np.linalg.norm(feature2))
    return dist