Spaces:
Runtime error
Runtime error
update infer_contrast
Browse files
.ipynb_checkpoints/infer_contrast-checkpoint.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import functools
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
|
7 |
+
from utils.reader import load_audio
|
8 |
+
from utils.utility import add_arguments, print_arguments
|
9 |
+
|
10 |
+
# Command-line configuration. `add_arg` is `add_arguments` pre-bound to this
# parser, so each call below only supplies name, type, default and help text.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('threshold', float, 0.71, '判断是否为同一个人的阈值')
add_arg('input_shape', str, '(1, 257, 257)', '数据输入的形状')
add_arg('model_path', str, 'models_large/resnet34.pth', '预测模型的路径')
# parse_known_args() is used instead of parse_args() so that unrecognized
# command-line options are ignored rather than aborting the program.
args = parser.parse_known_args()[0]

print_arguments(args)
print(torch.cuda.is_available())
# Inference is pinned to the CPU regardless of whether CUDA is available.
device = torch.device("cpu")

# Load the TorchScript model directly onto `device` and put it in eval mode.
model = torch.jit.load(args.model_path, map_location=device)
model.eval()
|
27 |
+
|
28 |
+
|
29 |
+
# 预测音频
|
30 |
+
def infer(audio_path):
    """Compute the model's output feature for one audio input.

    Args:
        audio_path: Audio input, passed unchanged to ``load_audio``.

    Returns:
        The model output converted to a NumPy array. A leading axis is added
        to the loaded data before it is fed to the model.
    """
    # Function-scope import keeps this block self-contained.
    import ast

    # NOTE(security): `input_shape` is a command-line string. It is parsed as
    # a Python literal instead of with eval(), which would execute arbitrary
    # code supplied by the caller.
    input_shape = ast.literal_eval(args.input_shape)
    data = load_audio(audio_path, mode='infer', spec_len=input_shape[2])
    data = data[np.newaxis, :]
    data = torch.tensor(data, dtype=torch.float32)
    # Run the forward pass; gradients are not needed for inference.
    with torch.no_grad():
        feature = model(data)
    return feature.detach().cpu().numpy()
|
38 |
+
|
39 |
+
|
40 |
+
def run(audio1, audio2):
    """Compare two audio inputs and report whether they match.

    Args:
        audio1: First audio input, forwarded to ``infer``.
        audio2: Second audio input, forwarded to ``infer``.

    Returns:
        A message string saying whether the cosine similarity of the two
        features exceeds ``args.threshold``, followed by the similarity value.
    """
    # Features of the two audio inputs to compare (first row of each result).
    feature1 = infer(audio1)[0]
    feature2 = infer(audio2)[0]
    # Cosine similarity of the two feature vectors.
    norm_product = np.linalg.norm(feature1) * np.linalg.norm(feature2)
    if norm_product == 0:
        # An all-zero feature has no direction; report 0 instead of dividing
        # by zero and propagating NaN into the message.
        similarity = 0.0
    else:
        similarity = np.dot(feature1, feature2) / norm_product
    if similarity > args.threshold:
        result = "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % (similarity)
    else:
        result = "Speaker1 和 Speaker2 不是同一个人,相似度为:%f" % (similarity)

    return result
|
infer_contrast.py
CHANGED
@@ -46,6 +46,6 @@ def run(audio1,audio2):
|
|
46 |
if dist > args.threshold:
|
47 |
result = "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % (dist)
|
48 |
else:
|
49 |
-
result = "Speaker1 和 Speaker2
|
50 |
|
51 |
return result
|
|
|
46 |
if dist > args.threshold:
|
47 |
result = "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % (dist)
|
48 |
else:
|
49 |
+
result = "Speaker1 和 Speaker2 不是同一个人,相似度为:%f" % (dist)
|
50 |
|
51 |
return result
|