Update voice_processing.py
Browse files- voice_processing.py +57 -2
voice_processing.py
CHANGED
@@ -47,10 +47,65 @@ def get_unique_filename(extension):
|
|
47 |
return f"{uuid.uuid4()}.{extension}"
|
48 |
|
49 |
def model_data(model_name):
|
50 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
def load_hubert():
|
53 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
54 |
|
55 |
def get_model_names():
|
56 |
model_root = "weights"
|
|
|
47 |
return f"{uuid.uuid4()}.{extension}"
|
48 |
|
49 |
def model_data(model_name):
|
50 |
+
pth_path = [
|
51 |
+
f"{model_root}/{model_name}/{f}"
|
52 |
+
for f in os.listdir(f"{model_root}/{model_name}")
|
53 |
+
if f.endswith(".pth")
|
54 |
+
][0]
|
55 |
+
print(f"Loading {pth_path}")
|
56 |
+
cpt = torch.load(pth_path, map_location="cpu")
|
57 |
+
tgt_sr = cpt["config"][-1]
|
58 |
+
cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
59 |
+
if_f0 = cpt.get("f0", 1)
|
60 |
+
version = cpt.get("version", "v1")
|
61 |
+
if version == "v1":
|
62 |
+
if if_f0 == 1:
|
63 |
+
net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
|
64 |
+
else:
|
65 |
+
net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
|
66 |
+
elif version == "v2":
|
67 |
+
if if_f0 == 1:
|
68 |
+
net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
|
69 |
+
else:
|
70 |
+
net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
|
71 |
+
else:
|
72 |
+
raise ValueError("Unknown version")
|
73 |
+
del net_g.enc_q
|
74 |
+
net_g.load_state_dict(cpt["weight"], strict=False)
|
75 |
+
print("Model loaded")
|
76 |
+
net_g.eval().to(config.device)
|
77 |
+
if config.is_half:
|
78 |
+
net_g = net_g.half()
|
79 |
+
else:
|
80 |
+
net_g = net_g.float()
|
81 |
+
vc = VC(tgt_sr, config)
|
82 |
+
|
83 |
+
index_files = [
|
84 |
+
f"{model_root}/{model_name}/{f}"
|
85 |
+
for f in os.listdir(f"{model_root}/{model_name}")
|
86 |
+
if f.endswith(".index")
|
87 |
+
]
|
88 |
+
if len(index_files) == 0:
|
89 |
+
print("No index file found")
|
90 |
+
index_file = ""
|
91 |
+
else:
|
92 |
+
index_file = index_files[0]
|
93 |
+
print(f"Index file found: {index_file}")
|
94 |
+
|
95 |
+
return tgt_sr, net_g, vc, version, index_file, if_f0
|
96 |
|
97 |
def load_hubert():
|
98 |
+
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
99 |
+
["hubert_base.pt"],
|
100 |
+
suffix="",
|
101 |
+
)
|
102 |
+
hubert_model = models[0]
|
103 |
+
hubert_model = hubert_model.to(config.device)
|
104 |
+
if config.is_half:
|
105 |
+
hubert_model = hubert_model.half()
|
106 |
+
else:
|
107 |
+
hubert_model = hubert_model.float()
|
108 |
+
return hubert_model.eval()
|
109 |
|
110 |
def get_model_names():
|
111 |
model_root = "weights"
|