Upload 7 files
Browse files- infer/modules/vc/__init__.py +5 -0
- infer/modules/vc/hash.py +202 -0
- infer/modules/vc/info.py +84 -0
- infer/modules/vc/lgdsng.npz +3 -0
- infer/modules/vc/modules.py +35 -60
- infer/modules/vc/pipeline.py +83 -96
- infer/modules/vc/utils.py +5 -4
infer/modules/vc/__init__.py
CHANGED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .pipeline import Pipeline
|
2 |
+
from .modules import VC
|
3 |
+
from .utils import get_index_path_from_model, load_hubert
|
4 |
+
from .info import show_info
|
5 |
+
from .hash import model_hash_ckpt, hash_id, hash_similarity
|
infer/modules/vc/hash.py
ADDED
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import hashlib
|
4 |
+
import pathlib
|
5 |
+
from scipy.fft import fft
|
6 |
+
from pybase16384 import encode_to_string, decode_from_string
|
7 |
+
|
8 |
+
from configs import CPUConfig, singleton_variable
|
9 |
+
from rvc.synthesizer import get_synthesizer
|
10 |
+
|
11 |
+
from .pipeline import Pipeline
|
12 |
+
from .utils import load_hubert
|
13 |
+
|
14 |
+
|
15 |
+
class TorchSeedContext:
|
16 |
+
def __init__(self, seed):
|
17 |
+
self.seed = seed
|
18 |
+
self.state = None
|
19 |
+
|
20 |
+
def __enter__(self):
|
21 |
+
self.state = torch.random.get_rng_state()
|
22 |
+
torch.manual_seed(self.seed)
|
23 |
+
|
24 |
+
def __exit__(self, type, value, traceback):
|
25 |
+
torch.random.set_rng_state(self.state)
|
26 |
+
|
27 |
+
|
28 |
+
half_hash_len = 512
|
29 |
+
expand_factor = 65536 * 8
|
30 |
+
|
31 |
+
|
32 |
+
@singleton_variable
|
33 |
+
def original_audio_storage():
|
34 |
+
return np.load(pathlib.Path(__file__).parent / "lgdsng.npz")
|
35 |
+
|
36 |
+
|
37 |
+
@singleton_variable
|
38 |
+
def original_audio():
|
39 |
+
return original_audio_storage()["a"]
|
40 |
+
|
41 |
+
|
42 |
+
@singleton_variable
|
43 |
+
def original_audio_time_minus():
|
44 |
+
return original_audio_storage()["t"]
|
45 |
+
|
46 |
+
|
47 |
+
@singleton_variable
|
48 |
+
def original_audio_freq_minus():
|
49 |
+
return original_audio_storage()["f"]
|
50 |
+
|
51 |
+
|
52 |
+
@singleton_variable
|
53 |
+
def original_rmvpe_f0():
|
54 |
+
x = original_audio_storage()
|
55 |
+
return x["pitch"], x["pitchf"]
|
56 |
+
|
57 |
+
|
58 |
+
def _cut_u16(n):
|
59 |
+
if n > 16384:
|
60 |
+
n = 16384 + 16384 * (1 - np.exp((16384 - n) / expand_factor))
|
61 |
+
elif n < -16384:
|
62 |
+
n = -16384 - 16384 * (1 - np.exp((n + 16384) / expand_factor))
|
63 |
+
return n
|
64 |
+
|
65 |
+
|
66 |
+
# wave_hash will change time_field, use carefully
|
67 |
+
def wave_hash(time_field):
|
68 |
+
np.divide(time_field, np.abs(time_field).max(), time_field)
|
69 |
+
if len(time_field) != 48000:
|
70 |
+
raise Exception("time not hashable")
|
71 |
+
freq_field = fft(time_field)
|
72 |
+
if len(freq_field) != 48000:
|
73 |
+
raise Exception("freq not hashable")
|
74 |
+
np.add(time_field, original_audio_time_minus(), out=time_field)
|
75 |
+
np.add(freq_field, original_audio_freq_minus(), out=freq_field)
|
76 |
+
hash = np.zeros(half_hash_len // 2 * 2, dtype=">i2")
|
77 |
+
d = 375 * 512 // half_hash_len
|
78 |
+
for i in range(half_hash_len // 4):
|
79 |
+
a = i * 2
|
80 |
+
b = a + 1
|
81 |
+
x = a + half_hash_len // 2
|
82 |
+
y = x + 1
|
83 |
+
s = np.average(freq_field[i * d : (i + 1) * d])
|
84 |
+
hash[a] = np.int16(_cut_u16(round(32768 * np.real(s))))
|
85 |
+
hash[b] = np.int16(_cut_u16(round(32768 * np.imag(s))))
|
86 |
+
hash[x] = np.int16(
|
87 |
+
_cut_u16(round(32768 * np.sum(time_field[i * d : i * d + d // 2])))
|
88 |
+
)
|
89 |
+
hash[y] = np.int16(
|
90 |
+
_cut_u16(round(32768 * np.sum(time_field[i * d + d // 2 : (i + 1) * d])))
|
91 |
+
)
|
92 |
+
return encode_to_string(hash.tobytes())
|
93 |
+
|
94 |
+
|
95 |
+
def model_hash(config, tgt_sr, net_g, if_f0, version):
|
96 |
+
pipeline = Pipeline(tgt_sr, config)
|
97 |
+
audio = original_audio()
|
98 |
+
hbt = load_hubert(config.device, config.is_half)
|
99 |
+
audio_opt = pipeline.pipeline(
|
100 |
+
hbt,
|
101 |
+
net_g,
|
102 |
+
0,
|
103 |
+
audio,
|
104 |
+
[0, 0, 0],
|
105 |
+
6,
|
106 |
+
original_rmvpe_f0(),
|
107 |
+
"",
|
108 |
+
0,
|
109 |
+
2 if if_f0 else 0,
|
110 |
+
3,
|
111 |
+
tgt_sr,
|
112 |
+
16000,
|
113 |
+
0.25,
|
114 |
+
version,
|
115 |
+
0.33,
|
116 |
+
)
|
117 |
+
del hbt
|
118 |
+
opt_len = len(audio_opt)
|
119 |
+
diff = 48000 - opt_len
|
120 |
+
if diff > 0:
|
121 |
+
audio_opt = np.pad(audio_opt, (diff, 0))
|
122 |
+
elif diff < 0:
|
123 |
+
n = diff // 2
|
124 |
+
n = -n
|
125 |
+
audio_opt = audio_opt[n:-n]
|
126 |
+
h = wave_hash(audio_opt)
|
127 |
+
del pipeline, audio_opt
|
128 |
+
return h
|
129 |
+
|
130 |
+
|
131 |
+
def model_hash_ckpt(cpt):
|
132 |
+
config = CPUConfig()
|
133 |
+
|
134 |
+
with TorchSeedContext(114514):
|
135 |
+
net_g, cpt = get_synthesizer(cpt, config.device)
|
136 |
+
tgt_sr = cpt["config"][-1]
|
137 |
+
if_f0 = cpt.get("f0", 1)
|
138 |
+
version = cpt.get("version", "v1")
|
139 |
+
|
140 |
+
if config.is_half:
|
141 |
+
net_g = net_g.half()
|
142 |
+
else:
|
143 |
+
net_g = net_g.float()
|
144 |
+
|
145 |
+
h = model_hash(config, tgt_sr, net_g, if_f0, version)
|
146 |
+
|
147 |
+
del net_g
|
148 |
+
|
149 |
+
return h
|
150 |
+
|
151 |
+
|
152 |
+
def model_hash_from(path):
|
153 |
+
cpt = torch.load(path, map_location="cpu")
|
154 |
+
h = model_hash_ckpt(cpt)
|
155 |
+
del cpt
|
156 |
+
return h
|
157 |
+
|
158 |
+
|
159 |
+
def _extend_difference(n, a, b):
|
160 |
+
if n < a:
|
161 |
+
n = a
|
162 |
+
elif n > b:
|
163 |
+
n = b
|
164 |
+
n -= a
|
165 |
+
n /= b - a
|
166 |
+
return n
|
167 |
+
|
168 |
+
|
169 |
+
def hash_similarity(h1: str, h2: str) -> float:
|
170 |
+
try:
|
171 |
+
h1b, h2b = decode_from_string(h1), decode_from_string(h2)
|
172 |
+
if len(h1b) != half_hash_len * 2 or len(h2b) != half_hash_len * 2:
|
173 |
+
raise Exception("invalid hash length")
|
174 |
+
h1n, h2n = np.frombuffer(h1b, dtype=">i2"), np.frombuffer(h2b, dtype=">i2")
|
175 |
+
d = 0
|
176 |
+
for i in range(half_hash_len // 4):
|
177 |
+
a = i * 2
|
178 |
+
b = a + 1
|
179 |
+
ax = complex(h1n[a], h1n[b])
|
180 |
+
bx = complex(h2n[a], h2n[b])
|
181 |
+
if abs(ax) == 0 or abs(bx) == 0:
|
182 |
+
continue
|
183 |
+
d += np.abs(ax - bx)
|
184 |
+
frac = np.linalg.norm(h1n) * np.linalg.norm(h2n)
|
185 |
+
cosine = (
|
186 |
+
np.dot(h1n.astype(np.float32), h2n.astype(np.float32)) / frac
|
187 |
+
if frac != 0
|
188 |
+
else 1.0
|
189 |
+
)
|
190 |
+
distance = _extend_difference(np.exp(-d / expand_factor), 0.5, 1.0)
|
191 |
+
return round((abs(cosine) + distance) / 2, 6)
|
192 |
+
except Exception as e:
|
193 |
+
return str(e)
|
194 |
+
|
195 |
+
|
196 |
+
def hash_id(h: str) -> str:
|
197 |
+
d = decode_from_string(h)
|
198 |
+
if len(d) != half_hash_len * 2:
|
199 |
+
return "invalid hash length"
|
200 |
+
return encode_to_string(
|
201 |
+
np.frombuffer(d, dtype=np.uint64).sum(keepdims=True).tobytes()
|
202 |
+
)[:-2] + encode_to_string(hashlib.md5(d).digest()[:7])
|
infer/modules/vc/info.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import traceback
|
2 |
+
from i18n.i18n import I18nAuto
|
3 |
+
from datetime import datetime
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from .hash import model_hash_ckpt, hash_id, hash_similarity
|
7 |
+
|
8 |
+
i18n = I18nAuto()
|
9 |
+
|
10 |
+
|
11 |
+
def show_model_info(cpt, show_long_id=False):
|
12 |
+
try:
|
13 |
+
h = model_hash_ckpt(cpt)
|
14 |
+
id = hash_id(h)
|
15 |
+
idread = cpt.get("id", "None")
|
16 |
+
hread = cpt.get("hash", "None")
|
17 |
+
if id != idread:
|
18 |
+
id += (
|
19 |
+
"("
|
20 |
+
+ i18n("Actually calculated")
|
21 |
+
+ "), "
|
22 |
+
+ idread
|
23 |
+
+ "("
|
24 |
+
+ i18n("Read from model")
|
25 |
+
+ ")"
|
26 |
+
)
|
27 |
+
sim = hash_similarity(h, hread)
|
28 |
+
if not isinstance(sim, str):
|
29 |
+
sim = "%.2f%%" % (sim * 100)
|
30 |
+
if not show_long_id:
|
31 |
+
h = i18n("Hidden")
|
32 |
+
if h != hread:
|
33 |
+
h = i18n("Similarity") + " " + sim + " -> " + h
|
34 |
+
elif h != hread:
|
35 |
+
h = (
|
36 |
+
i18n("Similarity")
|
37 |
+
+ " "
|
38 |
+
+ sim
|
39 |
+
+ " -> "
|
40 |
+
+ h
|
41 |
+
+ "("
|
42 |
+
+ i18n("Actually calculated")
|
43 |
+
+ "), "
|
44 |
+
+ hread
|
45 |
+
+ "("
|
46 |
+
+ i18n("Read from model")
|
47 |
+
+ ")"
|
48 |
+
)
|
49 |
+
txt = f"""{i18n("Model name")}: %s
|
50 |
+
{i18n("Sealing date")}: %s
|
51 |
+
{i18n("Model Author")}: %s
|
52 |
+
{i18n("Information")}: %s
|
53 |
+
{i18n("Sampling rate")}: %s
|
54 |
+
{i18n("Pitch guidance (f0)")}: %s
|
55 |
+
{i18n("Version")}: %s
|
56 |
+
{i18n("ID(short)")}: %s
|
57 |
+
{i18n("ID(long)")}: %s""" % (
|
58 |
+
cpt.get("name", i18n("Unknown")),
|
59 |
+
datetime.fromtimestamp(float(cpt.get("timestamp", 0))),
|
60 |
+
cpt.get("author", i18n("Unknown")),
|
61 |
+
cpt.get("info", i18n("None")),
|
62 |
+
cpt.get("sr", i18n("Unknown")),
|
63 |
+
i18n("Exist") if cpt.get("f0", 0) == 1 else i18n("Not exist"),
|
64 |
+
cpt.get("version", i18n("None")),
|
65 |
+
id,
|
66 |
+
h,
|
67 |
+
)
|
68 |
+
except:
|
69 |
+
txt = traceback.format_exc()
|
70 |
+
|
71 |
+
return txt
|
72 |
+
|
73 |
+
|
74 |
+
def show_info(path):
|
75 |
+
try:
|
76 |
+
if hasattr(path, "name"):
|
77 |
+
path = path.name
|
78 |
+
a = torch.load(path, map_location="cpu")
|
79 |
+
txt = show_model_info(a, show_long_id=True)
|
80 |
+
del a
|
81 |
+
except:
|
82 |
+
txt = traceback.format_exc()
|
83 |
+
|
84 |
+
return txt
|
infer/modules/vc/lgdsng.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:6c2cea46a96b75c3e3154486672c99b335549403f9f61c77ae5bb22854950864
|
3 |
+
size 708982
|
infer/modules/vc/modules.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1 |
import traceback
|
2 |
import logging
|
|
|
3 |
|
4 |
logger = logging.getLogger(__name__)
|
5 |
|
@@ -9,14 +10,10 @@ import torch
|
|
9 |
from io import BytesIO
|
10 |
|
11 |
from infer.lib.audio import load_audio, wav2
|
12 |
-
from
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
SynthesizerTrnMs768NSFsid_nono,
|
17 |
-
)
|
18 |
-
from infer.modules.vc.pipeline import Pipeline
|
19 |
-
from infer.modules.vc.utils import *
|
20 |
|
21 |
|
22 |
class VC:
|
@@ -62,71 +59,45 @@ class VC:
|
|
62 |
) = None
|
63 |
if torch.cuda.is_available():
|
64 |
torch.cuda.empty_cache()
|
|
|
|
|
65 |
###楼下不这么折腾清理不干净
|
|
|
66 |
self.if_f0 = self.cpt.get("f0", 1)
|
67 |
self.version = self.cpt.get("version", "v1")
|
68 |
-
if self.version == "v1":
|
69 |
-
if self.if_f0 == 1:
|
70 |
-
self.net_g = SynthesizerTrnMs256NSFsid(
|
71 |
-
*self.cpt["config"], is_half=self.config.is_half
|
72 |
-
)
|
73 |
-
else:
|
74 |
-
self.net_g = SynthesizerTrnMs256NSFsid_nono(*self.cpt["config"])
|
75 |
-
elif self.version == "v2":
|
76 |
-
if self.if_f0 == 1:
|
77 |
-
self.net_g = SynthesizerTrnMs768NSFsid(
|
78 |
-
*self.cpt["config"], is_half=self.config.is_half
|
79 |
-
)
|
80 |
-
else:
|
81 |
-
self.net_g = SynthesizerTrnMs768NSFsid_nono(*self.cpt["config"])
|
82 |
del self.net_g, self.cpt
|
83 |
if torch.cuda.is_available():
|
84 |
torch.cuda.empty_cache()
|
|
|
|
|
85 |
return (
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
"__type__": "update",
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
},
|
97 |
-
"",
|
98 |
-
"",
|
99 |
)
|
|
|
100 |
person = f'{os.getenv("weight_root")}/{sid}'
|
101 |
logger.info(f"Loading: {person}")
|
102 |
|
103 |
-
self.cpt =
|
104 |
self.tgt_sr = self.cpt["config"][-1]
|
105 |
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
106 |
self.if_f0 = self.cpt.get("f0", 1)
|
107 |
self.version = self.cpt.get("version", "v1")
|
108 |
|
109 |
-
synthesizer_class = {
|
110 |
-
("v1", 1): SynthesizerTrnMs256NSFsid,
|
111 |
-
("v1", 0): SynthesizerTrnMs256NSFsid_nono,
|
112 |
-
("v2", 1): SynthesizerTrnMs768NSFsid,
|
113 |
-
("v2", 0): SynthesizerTrnMs768NSFsid_nono,
|
114 |
-
}
|
115 |
-
|
116 |
-
self.net_g = synthesizer_class.get(
|
117 |
-
(self.version, self.if_f0), SynthesizerTrnMs256NSFsid
|
118 |
-
)(*self.cpt["config"], is_half=self.config.is_half)
|
119 |
-
|
120 |
-
del self.net_g.enc_q
|
121 |
-
|
122 |
-
self.net_g.load_state_dict(self.cpt["weight"], strict=False)
|
123 |
-
self.net_g.eval().to(self.config.device)
|
124 |
if self.config.is_half:
|
125 |
self.net_g = self.net_g.half()
|
126 |
else:
|
127 |
self.net_g = self.net_g.float()
|
128 |
-
|
129 |
self.pipeline = Pipeline(self.tgt_sr, self.config)
|
|
|
130 |
n_spk = self.cpt["config"][-3]
|
131 |
index = {"value": get_index_path_from_model(sid), "__type__": "update"}
|
132 |
logger.info("Select index: " + index["value"])
|
@@ -138,6 +109,7 @@ class VC:
|
|
138 |
to_return_protect1,
|
139 |
index,
|
140 |
index,
|
|
|
141 |
)
|
142 |
if to_return_protect
|
143 |
else {"visible": True, "maximum": n_spk, "__type__": "update"}
|
@@ -160,18 +132,22 @@ class VC:
|
|
160 |
):
|
161 |
if input_audio_path is None:
|
162 |
return "You need to upload an audio", None
|
|
|
|
|
163 |
f0_up_key = int(f0_up_key)
|
164 |
try:
|
165 |
audio = load_audio(input_audio_path, 16000)
|
166 |
audio_max = np.abs(audio).max() / 0.95
|
167 |
if audio_max > 1:
|
168 |
-
audio
|
169 |
times = [0, 0, 0]
|
170 |
|
171 |
if self.hubert_model is None:
|
172 |
-
self.hubert_model = load_hubert(self.config)
|
173 |
|
174 |
if file_index:
|
|
|
|
|
175 |
file_index = (
|
176 |
file_index.strip(" ")
|
177 |
.strip('"')
|
@@ -190,7 +166,6 @@ class VC:
|
|
190 |
self.net_g,
|
191 |
sid,
|
192 |
audio,
|
193 |
-
input_audio_path,
|
194 |
times,
|
195 |
f0_up_key,
|
196 |
f0_method,
|
@@ -204,25 +179,25 @@ class VC:
|
|
204 |
self.version,
|
205 |
protect,
|
206 |
f0_file,
|
207 |
-
)
|
208 |
if self.tgt_sr != resample_sr >= 16000:
|
209 |
tgt_sr = resample_sr
|
210 |
else:
|
211 |
tgt_sr = self.tgt_sr
|
212 |
index_info = (
|
213 |
-
"Index
|
214 |
if os.path.exists(file_index)
|
215 |
else "Index not used."
|
216 |
)
|
217 |
return (
|
218 |
-
"Success.\n%s\nTime
|
219 |
% (index_info, *times),
|
220 |
(tgt_sr, audio_opt),
|
221 |
)
|
222 |
-
except:
|
223 |
info = traceback.format_exc()
|
224 |
logger.warning(info)
|
225 |
-
return
|
226 |
|
227 |
def vc_multi(
|
228 |
self,
|
|
|
1 |
import traceback
|
2 |
import logging
|
3 |
+
import os
|
4 |
|
5 |
logger = logging.getLogger(__name__)
|
6 |
|
|
|
10 |
from io import BytesIO
|
11 |
|
12 |
from infer.lib.audio import load_audio, wav2
|
13 |
+
from rvc.synthesizer import get_synthesizer, load_synthesizer
|
14 |
+
from .info import show_model_info
|
15 |
+
from .pipeline import Pipeline
|
16 |
+
from .utils import get_index_path_from_model, load_hubert
|
|
|
|
|
|
|
|
|
17 |
|
18 |
|
19 |
class VC:
|
|
|
59 |
) = None
|
60 |
if torch.cuda.is_available():
|
61 |
torch.cuda.empty_cache()
|
62 |
+
elif torch.backends.mps.is_available():
|
63 |
+
torch.mps.empty_cache()
|
64 |
###楼下不这么折腾清理不干净
|
65 |
+
self.net_g, self.cpt = get_synthesizer(self.cpt, self.config.device)
|
66 |
self.if_f0 = self.cpt.get("f0", 1)
|
67 |
self.version = self.cpt.get("version", "v1")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
del self.net_g, self.cpt
|
69 |
if torch.cuda.is_available():
|
70 |
torch.cuda.empty_cache()
|
71 |
+
elif torch.backends.mps.is_available():
|
72 |
+
torch.mps.empty_cache()
|
73 |
return (
|
74 |
+
(
|
75 |
+
{"visible": False, "__type__": "update"},
|
76 |
+
to_return_protect0,
|
77 |
+
to_return_protect1,
|
78 |
+
{"value": to_return_protect[2], "__type__": "update"},
|
79 |
+
{"value": to_return_protect[3], "__type__": "update"},
|
80 |
+
{"value": "", "__type__": "update"},
|
81 |
+
)
|
82 |
+
if to_return_protect
|
83 |
+
else {"visible": True, "maximum": 0, "__type__": "update"}
|
|
|
|
|
|
|
84 |
)
|
85 |
+
|
86 |
person = f'{os.getenv("weight_root")}/{sid}'
|
87 |
logger.info(f"Loading: {person}")
|
88 |
|
89 |
+
self.net_g, self.cpt = load_synthesizer(person, self.config.device)
|
90 |
self.tgt_sr = self.cpt["config"][-1]
|
91 |
self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0] # n_spk
|
92 |
self.if_f0 = self.cpt.get("f0", 1)
|
93 |
self.version = self.cpt.get("version", "v1")
|
94 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
95 |
if self.config.is_half:
|
96 |
self.net_g = self.net_g.half()
|
97 |
else:
|
98 |
self.net_g = self.net_g.float()
|
|
|
99 |
self.pipeline = Pipeline(self.tgt_sr, self.config)
|
100 |
+
|
101 |
n_spk = self.cpt["config"][-3]
|
102 |
index = {"value": get_index_path_from_model(sid), "__type__": "update"}
|
103 |
logger.info("Select index: " + index["value"])
|
|
|
109 |
to_return_protect1,
|
110 |
index,
|
111 |
index,
|
112 |
+
show_model_info(self.cpt),
|
113 |
)
|
114 |
if to_return_protect
|
115 |
else {"visible": True, "maximum": n_spk, "__type__": "update"}
|
|
|
132 |
):
|
133 |
if input_audio_path is None:
|
134 |
return "You need to upload an audio", None
|
135 |
+
elif hasattr(input_audio_path, "name"):
|
136 |
+
input_audio_path = str(input_audio_path.name)
|
137 |
f0_up_key = int(f0_up_key)
|
138 |
try:
|
139 |
audio = load_audio(input_audio_path, 16000)
|
140 |
audio_max = np.abs(audio).max() / 0.95
|
141 |
if audio_max > 1:
|
142 |
+
np.divide(audio, audio_max, audio)
|
143 |
times = [0, 0, 0]
|
144 |
|
145 |
if self.hubert_model is None:
|
146 |
+
self.hubert_model = load_hubert(self.config.device, self.config.is_half)
|
147 |
|
148 |
if file_index:
|
149 |
+
if hasattr(file_index, "name"):
|
150 |
+
file_index = str(file_index.name)
|
151 |
file_index = (
|
152 |
file_index.strip(" ")
|
153 |
.strip('"')
|
|
|
166 |
self.net_g,
|
167 |
sid,
|
168 |
audio,
|
|
|
169 |
times,
|
170 |
f0_up_key,
|
171 |
f0_method,
|
|
|
179 |
self.version,
|
180 |
protect,
|
181 |
f0_file,
|
182 |
+
).astype(np.int16)
|
183 |
if self.tgt_sr != resample_sr >= 16000:
|
184 |
tgt_sr = resample_sr
|
185 |
else:
|
186 |
tgt_sr = self.tgt_sr
|
187 |
index_info = (
|
188 |
+
"Index: %s." % file_index
|
189 |
if os.path.exists(file_index)
|
190 |
else "Index not used."
|
191 |
)
|
192 |
return (
|
193 |
+
"Success.\n%s\nTime: npy: %.2fs, f0: %.2fs, infer: %.2fs."
|
194 |
% (index_info, *times),
|
195 |
(tgt_sr, audio_opt),
|
196 |
)
|
197 |
+
except Exception as e:
|
198 |
info = traceback.format_exc()
|
199 |
logger.warning(info)
|
200 |
+
return str(e), None
|
201 |
|
202 |
def vc_multi(
|
203 |
self,
|
infer/modules/vc/pipeline.py
CHANGED
@@ -5,40 +5,22 @@ import logging
|
|
5 |
|
6 |
logger = logging.getLogger(__name__)
|
7 |
|
8 |
-
from
|
9 |
-
from time import time as ttime
|
10 |
|
11 |
import faiss
|
12 |
import librosa
|
13 |
import numpy as np
|
14 |
-
import parselmouth
|
15 |
-
import pyworld
|
16 |
import torch
|
17 |
import torch.nn.functional as F
|
18 |
-
import torchcrepe
|
19 |
from scipy import signal
|
20 |
|
|
|
|
|
21 |
now_dir = os.getcwd()
|
22 |
sys.path.append(now_dir)
|
23 |
|
24 |
bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
|
25 |
|
26 |
-
input_audio_path2wav = {}
|
27 |
-
|
28 |
-
|
29 |
-
@lru_cache
|
30 |
-
def cache_harvest_f0(input_audio_path, fs, f0max, f0min, frame_period):
|
31 |
-
audio = input_audio_path2wav[input_audio_path]
|
32 |
-
f0, t = pyworld.harvest(
|
33 |
-
audio,
|
34 |
-
fs=fs,
|
35 |
-
f0_ceil=f0max,
|
36 |
-
f0_floor=f0min,
|
37 |
-
frame_period=frame_period,
|
38 |
-
)
|
39 |
-
f0 = pyworld.stonemask(audio, f0, t, fs)
|
40 |
-
return f0
|
41 |
-
|
42 |
|
43 |
def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出音频,rate是2的占比
|
44 |
# print(data1.max(),data2.max())
|
@@ -83,7 +65,6 @@ class Pipeline(object):
|
|
83 |
|
84 |
def get_f0(
|
85 |
self,
|
86 |
-
input_audio_path,
|
87 |
x,
|
88 |
p_len,
|
89 |
f0_up_key,
|
@@ -91,73 +72,62 @@ class Pipeline(object):
|
|
91 |
filter_radius,
|
92 |
inp_f0=None,
|
93 |
):
|
94 |
-
global input_audio_path2wav
|
95 |
-
time_step = self.window / self.sr * 1000
|
96 |
f0_min = 50
|
97 |
f0_max = 1100
|
98 |
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
|
99 |
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
|
100 |
if f0_method == "pm":
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
)
|
109 |
-
.selected_array["frequency"]
|
110 |
-
)
|
111 |
-
pad_size = (p_len - len(f0) + 1) // 2
|
112 |
-
if pad_size > 0 or p_len - len(f0) - pad_size > 0:
|
113 |
-
f0 = np.pad(
|
114 |
-
f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant"
|
115 |
-
)
|
116 |
elif f0_method == "harvest":
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
f0 = signal.medfilt(f0, 3)
|
121 |
elif f0_method == "crepe":
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
131 |
-
f0_min,
|
132 |
-
f0_max,
|
133 |
-
model,
|
134 |
-
batch_size=batch_size,
|
135 |
-
device=self.device,
|
136 |
-
return_periodicity=True,
|
137 |
-
)
|
138 |
-
pd = torchcrepe.filter.median(pd, 3)
|
139 |
-
f0 = torchcrepe.filter.mean(f0, 3)
|
140 |
-
f0[pd < 0.1] = 0
|
141 |
-
f0 = f0[0].cpu().numpy()
|
142 |
elif f0_method == "rmvpe":
|
143 |
-
if not hasattr(self, "
|
144 |
-
from infer.lib.rmvpe import RMVPE
|
145 |
-
|
146 |
logger.info(
|
147 |
-
"Loading rmvpe model
|
148 |
)
|
149 |
-
self.
|
150 |
"%s/rmvpe.pt" % os.environ["rmvpe_root"],
|
151 |
is_half=self.is_half,
|
152 |
device=self.device,
|
|
|
153 |
)
|
154 |
-
f0 = self.
|
155 |
|
156 |
if "privateuseone" in str(self.device): # clean ortruntime memory
|
157 |
-
del self.
|
158 |
-
del self.
|
159 |
logger.info("Cleaning ortruntime memory")
|
160 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
f0 *= pow(2, f0_up_key / 12)
|
162 |
# with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
|
163 |
tf0 = self.sr // self.window # 每秒f0点数
|
@@ -214,7 +184,7 @@ class Pipeline(object):
|
|
214 |
"padding_mask": padding_mask,
|
215 |
"output_layer": 9 if version == "v1" else 12,
|
216 |
}
|
217 |
-
t0 =
|
218 |
with torch.no_grad():
|
219 |
logits = model.extract_features(**inputs)
|
220 |
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
@@ -232,7 +202,10 @@ class Pipeline(object):
|
|
232 |
# _, I = index.search(npy, 1)
|
233 |
# npy = big_npy[I.squeeze()]
|
234 |
|
235 |
-
|
|
|
|
|
|
|
236 |
weight = np.square(1 / score)
|
237 |
weight /= weight.sum(axis=1, keepdims=True)
|
238 |
npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
@@ -249,7 +222,7 @@ class Pipeline(object):
|
|
249 |
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
|
250 |
0, 2, 1
|
251 |
)
|
252 |
-
t1 =
|
253 |
p_len = audio0.shape[0] // self.window
|
254 |
if feats.shape[1] < p_len:
|
255 |
p_len = feats.shape[1]
|
@@ -266,14 +239,26 @@ class Pipeline(object):
|
|
266 |
feats = feats.to(feats0.dtype)
|
267 |
p_len = torch.tensor([p_len], device=self.device).long()
|
268 |
with torch.no_grad():
|
269 |
-
|
270 |
-
|
271 |
-
|
272 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
273 |
del feats, p_len, padding_mask
|
274 |
if torch.cuda.is_available():
|
275 |
torch.cuda.empty_cache()
|
276 |
-
|
|
|
|
|
277 |
times[0] += t1 - t0
|
278 |
times[2] += t2 - t1
|
279 |
return audio1
|
@@ -284,7 +269,6 @@ class Pipeline(object):
|
|
284 |
net_g,
|
285 |
sid,
|
286 |
audio,
|
287 |
-
input_audio_path,
|
288 |
times,
|
289 |
f0_up_key,
|
290 |
f0_method,
|
@@ -308,7 +292,6 @@ class Pipeline(object):
|
|
308 |
):
|
309 |
try:
|
310 |
index = faiss.read_index(file_index)
|
311 |
-
# big_npy = np.load(file_big_npy)
|
312 |
big_npy = index.reconstruct_n(0, index.ntotal)
|
313 |
except:
|
314 |
traceback.print_exc()
|
@@ -334,7 +317,7 @@ class Pipeline(object):
|
|
334 |
s = 0
|
335 |
audio_opt = []
|
336 |
t = None
|
337 |
-
t1 =
|
338 |
audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
|
339 |
p_len = audio_pad.shape[0] // self.window
|
340 |
inp_f0 = None
|
@@ -350,27 +333,29 @@ class Pipeline(object):
|
|
350 |
traceback.print_exc()
|
351 |
sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
|
352 |
pitch, pitchf = None, None
|
353 |
-
if if_f0
|
354 |
-
|
355 |
-
|
356 |
-
|
357 |
-
|
358 |
-
|
359 |
-
|
360 |
-
|
361 |
-
|
362 |
-
|
|
|
|
|
363 |
pitch = pitch[:p_len]
|
364 |
pitchf = pitchf[:p_len]
|
365 |
if "mps" not in str(self.device) or "xpu" not in str(self.device):
|
366 |
pitchf = pitchf.astype(np.float32)
|
367 |
pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
|
368 |
pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
|
369 |
-
t2 =
|
370 |
times[1] += t2 - t1
|
371 |
for t in opt_ts:
|
372 |
t = t // self.window * self.window
|
373 |
-
if if_f0
|
374 |
audio_opt.append(
|
375 |
self.vc(
|
376 |
model,
|
@@ -405,7 +390,7 @@ class Pipeline(object):
|
|
405 |
)[self.t_pad_tgt : -self.t_pad_tgt]
|
406 |
)
|
407 |
s = t
|
408 |
-
if if_f0
|
409 |
audio_opt.append(
|
410 |
self.vc(
|
411 |
model,
|
@@ -450,8 +435,10 @@ class Pipeline(object):
|
|
450 |
max_int16 = 32768
|
451 |
if audio_max > 1:
|
452 |
max_int16 /= audio_max
|
453 |
-
audio_opt
|
454 |
del pitch, pitchf, sid
|
455 |
if torch.cuda.is_available():
|
456 |
torch.cuda.empty_cache()
|
|
|
|
|
457 |
return audio_opt
|
|
|
5 |
|
6 |
logger = logging.getLogger(__name__)
|
7 |
|
8 |
+
from time import time
|
|
|
9 |
|
10 |
import faiss
|
11 |
import librosa
|
12 |
import numpy as np
|
|
|
|
|
13 |
import torch
|
14 |
import torch.nn.functional as F
|
|
|
15 |
from scipy import signal
|
16 |
|
17 |
+
from rvc.f0 import PM, Harvest, RMVPE, CRePE, Dio, FCPE
|
18 |
+
|
19 |
now_dir = os.getcwd()
|
20 |
sys.path.append(now_dir)
|
21 |
|
22 |
bh, ah = signal.butter(N=5, Wn=48, btype="high", fs=16000)
|
23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
def change_rms(data1, sr1, data2, sr2, rate): # 1是输入音频,2是输出音频,rate是2的占比
|
26 |
# print(data1.max(),data2.max())
|
|
|
65 |
|
66 |
def get_f0(
|
67 |
self,
|
|
|
68 |
x,
|
69 |
p_len,
|
70 |
f0_up_key,
|
|
|
72 |
filter_radius,
|
73 |
inp_f0=None,
|
74 |
):
|
|
|
|
|
75 |
f0_min = 50
|
76 |
f0_max = 1100
|
77 |
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
|
78 |
f0_mel_max = 1127 * np.log(1 + f0_max / 700)
|
79 |
if f0_method == "pm":
|
80 |
+
if not hasattr(self, "pm"):
|
81 |
+
self.pm = PM(self.window, f0_min, f0_max, self.sr)
|
82 |
+
f0 = self.pm.compute_f0(x, p_len=p_len)
|
83 |
+
if f0_method == "dio":
|
84 |
+
if not hasattr(self, "dio"):
|
85 |
+
self.dio = Dio(self.window, f0_min, f0_max, self.sr)
|
86 |
+
f0 = self.dio.compute_f0(x, p_len=p_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
87 |
elif f0_method == "harvest":
|
88 |
+
if not hasattr(self, "harvest"):
|
89 |
+
self.harvest = Harvest(self.window, f0_min, f0_max, self.sr)
|
90 |
+
f0 = self.harvest.compute_f0(x, p_len=p_len, filter_radius=filter_radius)
|
|
|
91 |
elif f0_method == "crepe":
|
92 |
+
if not hasattr(self, "crepe"):
|
93 |
+
self.crepe = CRePE(
|
94 |
+
self.window,
|
95 |
+
f0_min,
|
96 |
+
f0_max,
|
97 |
+
self.sr,
|
98 |
+
self.device,
|
99 |
+
)
|
100 |
+
f0 = self.crepe.compute_f0(x, p_len=p_len)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
101 |
elif f0_method == "rmvpe":
|
102 |
+
if not hasattr(self, "rmvpe"):
|
|
|
|
|
103 |
logger.info(
|
104 |
+
"Loading rmvpe model %s" % "%s/rmvpe.pt" % os.environ["rmvpe_root"]
|
105 |
)
|
106 |
+
self.rmvpe = RMVPE(
|
107 |
"%s/rmvpe.pt" % os.environ["rmvpe_root"],
|
108 |
is_half=self.is_half,
|
109 |
device=self.device,
|
110 |
+
# use_jit=self.config.use_jit,
|
111 |
)
|
112 |
+
f0 = self.rmvpe.compute_f0(x, p_len=p_len, filter_radius=0.03)
|
113 |
|
114 |
if "privateuseone" in str(self.device): # clean ortruntime memory
|
115 |
+
del self.rmvpe.model
|
116 |
+
del self.rmvpe
|
117 |
logger.info("Cleaning ortruntime memory")
|
118 |
|
119 |
+
elif f0_method == "fcpe":
|
120 |
+
if not hasattr(self, "model_fcpe"):
|
121 |
+
logger.info("Loading fcpe model")
|
122 |
+
self.model_fcpe = FCPE(
|
123 |
+
self.window,
|
124 |
+
f0_min,
|
125 |
+
f0_max,
|
126 |
+
self.sr,
|
127 |
+
self.device,
|
128 |
+
)
|
129 |
+
f0 = self.model_fcpe.compute_f0(x, p_len=p_len)
|
130 |
+
|
131 |
f0 *= pow(2, f0_up_key / 12)
|
132 |
# with open("test.txt","w")as f:f.write("\n".join([str(i)for i in f0.tolist()]))
|
133 |
tf0 = self.sr // self.window # 每秒f0点数
|
|
|
184 |
"padding_mask": padding_mask,
|
185 |
"output_layer": 9 if version == "v1" else 12,
|
186 |
}
|
187 |
+
t0 = time()
|
188 |
with torch.no_grad():
|
189 |
logits = model.extract_features(**inputs)
|
190 |
feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
|
|
|
202 |
# _, I = index.search(npy, 1)
|
203 |
# npy = big_npy[I.squeeze()]
|
204 |
|
205 |
+
try:
|
206 |
+
score, ix = index.search(npy, k=8)
|
207 |
+
except:
|
208 |
+
raise Exception("index mistatch")
|
209 |
weight = np.square(1 / score)
|
210 |
weight /= weight.sum(axis=1, keepdims=True)
|
211 |
npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
|
|
|
222 |
feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(
|
223 |
0, 2, 1
|
224 |
)
|
225 |
+
t1 = time()
|
226 |
p_len = audio0.shape[0] // self.window
|
227 |
if feats.shape[1] < p_len:
|
228 |
p_len = feats.shape[1]
|
|
|
239 |
feats = feats.to(feats0.dtype)
|
240 |
p_len = torch.tensor([p_len], device=self.device).long()
|
241 |
with torch.no_grad():
|
242 |
+
audio1 = (
|
243 |
+
(
|
244 |
+
net_g.infer(
|
245 |
+
feats,
|
246 |
+
p_len,
|
247 |
+
sid,
|
248 |
+
pitch=pitch,
|
249 |
+
pitchf=pitchf,
|
250 |
+
)[0, 0]
|
251 |
+
)
|
252 |
+
.data.cpu()
|
253 |
+
.float()
|
254 |
+
.numpy()
|
255 |
+
)
|
256 |
del feats, p_len, padding_mask
|
257 |
if torch.cuda.is_available():
|
258 |
torch.cuda.empty_cache()
|
259 |
+
elif torch.backends.mps.is_available():
|
260 |
+
torch.mps.empty_cache()
|
261 |
+
t2 = time()
|
262 |
times[0] += t1 - t0
|
263 |
times[2] += t2 - t1
|
264 |
return audio1
|
|
|
269 |
net_g,
|
270 |
sid,
|
271 |
audio,
|
|
|
272 |
times,
|
273 |
f0_up_key,
|
274 |
f0_method,
|
|
|
292 |
):
|
293 |
try:
|
294 |
index = faiss.read_index(file_index)
|
|
|
295 |
big_npy = index.reconstruct_n(0, index.ntotal)
|
296 |
except:
|
297 |
traceback.print_exc()
|
|
|
317 |
s = 0
|
318 |
audio_opt = []
|
319 |
t = None
|
320 |
+
t1 = time()
|
321 |
audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
|
322 |
p_len = audio_pad.shape[0] // self.window
|
323 |
inp_f0 = None
|
|
|
333 |
traceback.print_exc()
|
334 |
sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
|
335 |
pitch, pitchf = None, None
|
336 |
+
if if_f0:
|
337 |
+
if if_f0 == 1:
|
338 |
+
pitch, pitchf = self.get_f0(
|
339 |
+
audio_pad,
|
340 |
+
p_len,
|
341 |
+
f0_up_key,
|
342 |
+
f0_method,
|
343 |
+
filter_radius,
|
344 |
+
inp_f0,
|
345 |
+
)
|
346 |
+
elif if_f0 == 2:
|
347 |
+
pitch, pitchf = f0_method
|
348 |
pitch = pitch[:p_len]
|
349 |
pitchf = pitchf[:p_len]
|
350 |
if "mps" not in str(self.device) or "xpu" not in str(self.device):
|
351 |
pitchf = pitchf.astype(np.float32)
|
352 |
pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
|
353 |
pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
|
354 |
+
t2 = time()
|
355 |
times[1] += t2 - t1
|
356 |
for t in opt_ts:
|
357 |
t = t // self.window * self.window
|
358 |
+
if if_f0:
|
359 |
audio_opt.append(
|
360 |
self.vc(
|
361 |
model,
|
|
|
390 |
)[self.t_pad_tgt : -self.t_pad_tgt]
|
391 |
)
|
392 |
s = t
|
393 |
+
if if_f0:
|
394 |
audio_opt.append(
|
395 |
self.vc(
|
396 |
model,
|
|
|
435 |
max_int16 = 32768
|
436 |
if audio_max > 1:
|
437 |
max_int16 /= audio_max
|
438 |
+
np.multiply(audio_opt, max_int16, audio_opt)
|
439 |
del pitch, pitchf, sid
|
440 |
if torch.cuda.is_available():
|
441 |
torch.cuda.empty_cache()
|
442 |
+
elif torch.backends.mps.is_available():
|
443 |
+
torch.mps.empty_cache()
|
444 |
return audio_opt
|
infer/modules/vc/utils.py
CHANGED
@@ -9,7 +9,8 @@ def get_index_path_from_model(sid):
|
|
9 |
f
|
10 |
for f in [
|
11 |
os.path.join(root, name)
|
12 |
-
for
|
|
|
13 |
for name in files
|
14 |
if name.endswith(".index") and "trained" not in name
|
15 |
]
|
@@ -19,14 +20,14 @@ def get_index_path_from_model(sid):
|
|
19 |
)
|
20 |
|
21 |
|
22 |
-
def load_hubert(
|
23 |
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
24 |
["assets/hubert/hubert_base.pt"],
|
25 |
suffix="",
|
26 |
)
|
27 |
hubert_model = models[0]
|
28 |
-
hubert_model = hubert_model.to(
|
29 |
-
if
|
30 |
hubert_model = hubert_model.half()
|
31 |
else:
|
32 |
hubert_model = hubert_model.float()
|
|
|
9 |
f
|
10 |
for f in [
|
11 |
os.path.join(root, name)
|
12 |
+
for path in [os.getenv("outside_index_root"), os.getenv("index_root")]
|
13 |
+
for root, _, files in os.walk(path, topdown=False)
|
14 |
for name in files
|
15 |
if name.endswith(".index") and "trained" not in name
|
16 |
]
|
|
|
20 |
)
|
21 |
|
22 |
|
23 |
+
def load_hubert(device, is_half):
|
24 |
models, _, _ = checkpoint_utils.load_model_ensemble_and_task(
|
25 |
["assets/hubert/hubert_base.pt"],
|
26 |
suffix="",
|
27 |
)
|
28 |
hubert_model = models[0]
|
29 |
+
hubert_model = hubert_model.to(device)
|
30 |
+
if is_half:
|
31 |
hubert_model = hubert_model.half()
|
32 |
else:
|
33 |
hubert_model = hubert_model.float()
|