smjain committed on
Commit
0ccc2cf
1 Parent(s): 4afb64f

Upload extract_feature_print.py

Files changed (1)
  1. extract_feature_print.py +137 -0
extract_feature_print.py ADDED
@@ -0,0 +1,137 @@
+ import os
+ import sys
+ import traceback
+
+ os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
+ os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"
+
+ # Command-line arguments: device hint, number of parallel parts, index of this
+ # part, optionally a GPU index, the experiment directory, and the model version.
+ device = sys.argv[1]
+ n_part = int(sys.argv[2])
+ i_part = int(sys.argv[3])
+ if len(sys.argv) == 6:
+     exp_dir = sys.argv[4]
+     version = sys.argv[5]
+ else:
+     i_gpu = sys.argv[4]
+     exp_dir = sys.argv[5]
+     os.environ["CUDA_VISIBLE_DEVICES"] = str(i_gpu)
+     version = sys.argv[6]
+ import fairseq
+ import numpy as np
+ import soundfile as sf
+ import torch
+ import torch.nn.functional as F
+
+ # Resolve the compute device: prefer CUDA, then MPS, otherwise CPU;
+ # a "privateuseone" hint selects DirectML.
+ if "privateuseone" not in device:
+     device = "cpu"
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+ else:
+     import torch_directml
+
+     device = torch_directml.device(torch_directml.default_device())
+
+     def forward_dml(ctx, x, scale):
+         ctx.scale = scale
+         res = x.clone().detach()
+         return res
+
+     # DirectML does not support GradMultiply, so patch it with a plain copy.
+     fairseq.modules.grad_multiply.GradMultiply.forward = forward_dml
+
+ f = open("%s/extract_f0_feature.log" % exp_dir, "a+")
+
+
+ def printt(strr):
+     print(strr)
+     f.write("%s\n" % strr)
+     f.flush()
+
+
+ printt(sys.argv)
+ model_path = "assets/hubert/hubert_base.pt"
+
+ printt(exp_dir)
+ wavPath = "%s/1_16k_wavs" % exp_dir
+ outPath = (
+     "%s/3_feature256" % exp_dir if version == "v1" else "%s/3_feature768" % exp_dir
+ )
+ os.makedirs(outPath, exist_ok=True)
+
+
+ # wave must be 16 kHz, hop_size=320
+ def readwave(wav_path, normalize=False):
+     wav, sr = sf.read(wav_path)
+     assert sr == 16000
+     feats = torch.from_numpy(wav).float()
+     if feats.dim() == 2:  # stereo: average the two channels
+         feats = feats.mean(-1)
+     assert feats.dim() == 1, feats.dim()
+     if normalize:
+         with torch.no_grad():
+             feats = F.layer_norm(feats, feats.shape)
+     feats = feats.view(1, -1)
+     return feats
+
+
+ # HuBERT model
+ printt("load model(s) from {}".format(model_path))
+ # make sure the HuBERT checkpoint exists before loading it
+ if not os.access(model_path, os.F_OK):
+     printt(
+         "Error: extraction stopped because %s does not exist; you can download it from https://huggingface.co/lj1995/VoiceConversionWebUI/tree/main"
+         % model_path
+     )
+     exit(0)
+ models, saved_cfg, task = fairseq.checkpoint_utils.load_model_ensemble_and_task(
+     [model_path],
+     suffix="",
+ )
+ model = models[0]
+ model = model.to(device)
+ printt("move model to %s" % device)
+ if device not in ["mps", "cpu"]:
+     model = model.half()
+ model.eval()
+
+ # Each worker processes every n_part-th file, starting at offset i_part.
+ todo = sorted(list(os.listdir(wavPath)))[i_part::n_part]
+ n = max(1, len(todo) // 10)  # print at most ten progress lines
+ if len(todo) == 0:
+     printt("no-feature-todo")
+ else:
+     printt("all-feature-%s" % len(todo))
+     for idx, file in enumerate(todo):
+         try:
+             if file.endswith(".wav"):
+                 wav_path = "%s/%s" % (wavPath, file)
+                 out_path = "%s/%s" % (outPath, file.replace("wav", "npy"))
+
+                 if os.path.exists(out_path):
+                     continue
+
+                 feats = readwave(wav_path, normalize=saved_cfg.task.normalize)
+                 padding_mask = torch.BoolTensor(feats.shape).fill_(False)
+                 inputs = {
+                     "source": feats.half().to(device)
+                     if device not in ["mps", "cpu"]
+                     else feats.to(device),
+                     "padding_mask": padding_mask.to(device),
+                     "output_layer": 9 if version == "v1" else 12,  # layer 9 for v1, layer 12 for v2
+                 }
+                 with torch.no_grad():
+                     logits = model.extract_features(**inputs)
+                     feats = (
+                         model.final_proj(logits[0]) if version == "v1" else logits[0]
+                     )
+
+                 feats = feats.squeeze(0).float().cpu().numpy()
+                 if np.isnan(feats).sum() == 0:
+                     np.save(out_path, feats, allow_pickle=False)
+                 else:
+                     printt("%s-contains nan" % file)
+                 if idx % n == 0:
+                     printt("now-%s,all-%s,%s,%s" % (idx, len(todo), file, feats.shape))
+         except:
+             printt(traceback.format_exc())
+     printt("all-feature-done")
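For reference, a minimal launch sketch based purely on the argument handling at the top of the script, not on how the RVC web UI actually invokes it; the experiment directory and settings below are hypothetical placeholders, and the script still expects assets/hubert/hubert_base.pt plus a 1_16k_wavs/ folder of 16 kHz wavs inside the experiment directory.

# Launch sketch (assumption: paths and settings are placeholders, adjust to your setup).
import subprocess
import sys

exp_dir = "logs/my-experiment"  # hypothetical; must contain 1_16k_wavs/

subprocess.run(
    [
        sys.executable,
        "extract_feature_print.py",
        "cuda",  # device hint; the script re-resolves cpu/cuda/mps on its own
        "1",     # n_part: total number of parallel workers
        "0",     # i_part: index of this worker (0-based)
        "0",     # i_gpu: exported as CUDA_VISIBLE_DEVICES (7-argument form)
        exp_dir,
        "v2",    # "v1" writes 3_feature256/, "v2" writes 3_feature768/
    ],
    check=True,
)

Each processed wav is saved as a float32 .npy array of shape (frames, 256) for v1 (after final_proj) or (frames, 768) for v2, in the corresponding feature folder.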