YuAnthony committed on
Commit
6bb0077
·
1 Parent(s): c44ab12

update infer_contrast

Browse files
.ipynb_checkpoints/infer_contrast-checkpoint.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+
4
+ import numpy as np
5
+ import torch
6
+
7
+ from utils.reader import load_audio
8
+ from utils.utility import add_arguments, print_arguments
9
+
10
# Command-line options. Each add_arg call forwards to the project's
# add_arguments helper with this parser already bound.
parser = argparse.ArgumentParser(description=__doc__)
add_arg = functools.partial(add_arguments, argparser=parser)
add_arg('threshold', float, 0.71, '判断是否为同一个人的阈值')
add_arg('input_shape', str, '(1, 257, 257)', '数据输入的形状')
add_arg('model_path', str, 'models_large/resnet34.pth', '预测模型的路径')

# parse_known_args returns (namespace, leftover_argv); only the namespace is
# kept, so unrecognised command-line entries do not abort start-up.
args = parser.parse_known_args()[0]

print_arguments(args)
print(torch.cuda.is_available())
device = torch.device("cpu")

# Load the TorchScript model with its tensors mapped to the CPU, then switch
# it to evaluation mode.
model = torch.jit.load(args.model_path, map_location=device)
model.eval()
27
+
28
+
29
+ # 预测音频
30
def infer(audio_path):
    """Run the loaded model on one audio file and return its output feature.

    Args:
        audio_path: Path of the audio file; passed unchanged to ``load_audio``.

    Returns:
        The model output as a NumPy array. The input is given one new leading
        axis before the forward pass, so callers index ``[0]`` to get the
        feature for this file.

    Raises:
        ValueError, SyntaxError: If ``args.input_shape`` is not a valid Python
            literal.
    """
    # Function-scope import so this block does not depend on the file header.
    import ast

    # literal_eval only accepts Python literals, so a crafted --input_shape
    # value cannot execute arbitrary code the way eval() would.
    input_shape = ast.literal_eval(args.input_shape)
    # The third entry of the configured shape is used as the spectrogram length.
    data = load_audio(audio_path, mode='infer', spec_len=input_shape[2])
    data = data[np.newaxis, :]
    data = torch.tensor(data, dtype=torch.float32)
    # Inference only: skip autograd bookkeeping during the forward pass.
    with torch.no_grad():
        feature = model(data)
    return feature.detach().cpu().numpy()
38
+
39
+
40
def run(audio1, audio2):
    """Compare two audio files and report whether they match as one speaker.

    Args:
        audio1: Path of the first audio file.
        audio2: Path of the second audio file.

    Returns:
        A message string stating whether the cosine similarity of the two
        features exceeds ``args.threshold``, followed by the similarity value.
    """
    # Features of the two recordings to compare (first row of each result),
    # computed in argument order.
    emb_a, emb_b = (infer(path)[0] for path in (audio1, audio2))

    # Cosine of the angle between the two feature vectors.
    norm_product = np.linalg.norm(emb_a) * np.linalg.norm(emb_b)
    similarity = np.dot(emb_a, emb_b) / norm_product

    if similarity > args.threshold:
        return "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % similarity
    return "Speaker1 和 Speaker2 不是同一个人,相似度为:%f" % similarity
infer_contrast.py CHANGED
@@ -46,6 +46,6 @@ def run(audio1,audio2):
46
  if dist > args.threshold:
47
  result = "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % (dist)
48
  else:
49
- result = "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % (dist)
50
 
51
  return result
 
46
  if dist > args.threshold:
47
  result = "Speaker1 和 Speaker2 为同一个人,相似度为:%f" % (dist)
48
  else:
49
+ result = "Speaker1 和 Speaker2 不是同一个人,相似度为:%f" % (dist)
50
 
51
  return result