Update pico_model.py
- pico_model.py +1 -31

pico_model.py CHANGED
@@ -12,36 +12,6 @@ from audioldm.audio.stft import TacotronSTFT
 from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
 from audioldm.utils import default_audioldm_config, get_metadata
 
-
-
-def build_pretrained_models(name):
-    checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
-    scale_factor = checkpoint["state_dict"]["scale_factor"].item()
-
-    vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
-
-    config = default_audioldm_config(name)
-    vae_config = config["model"]["params"]["first_stage_config"]["params"]
-    vae_config["scale_factor"] = scale_factor
-
-    vae = AutoencoderKL(**vae_config)
-    vae.load_state_dict(vae_state_dict)
-
-    fn_STFT = TacotronSTFT(
-        config["preprocessing"]["stft"]["filter_length"],
-        config["preprocessing"]["stft"]["hop_length"],
-        config["preprocessing"]["stft"]["win_length"],
-        config["preprocessing"]["mel"]["n_mel_channels"],
-        config["preprocessing"]["audio"]["sampling_rate"],
-        config["preprocessing"]["mel"]["mel_fmin"],
-        config["preprocessing"]["mel"]["mel_fmax"],
-    )
-
-    vae.eval()
-    fn_STFT.eval()
-
-    return vae, fn_STFT
-
 def _init_layer(layer):
     """Initialize a Linear or Convolutional layer. """
     nn.init.xavier_uniform_(layer.weight)
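Note on the removed hunk: build_pretrained_models loaded an AudioLDM checkpoint on CPU, recovered the VAE weights by stripping the "first_stage_model." prefix (18 characters, hence the k[18:] slice), and built the matching TacotronSTFT from the preprocessing config. For readers puzzled by the hard-coded index, here is a minimal sketch of that prefix-stripping step with the prefix spelled out; the helper name is hypothetical and it uses startswith rather than the original substring test:

PREFIX = "first_stage_model."  # len(PREFIX) == 18, which is what k[18:] relied on

def extract_vae_state_dict(checkpoint: dict) -> dict:
    # Keep only the VAE parameters and drop the prefix so the keys
    # line up with AutoencoderKL's own state_dict layout.
    return {
        k[len(PREFIX):]: v
        for k, v in checkpoint["state_dict"].items()
        if k.startswith(PREFIX)
    }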
@@ -260,7 +230,7 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
         ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
         del_parameter_key = ["text_branch.embeddings.position_ids"]
         ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
-        diffusion_ckpt = torch.load(diffusion_pt)
+        diffusion_ckpt = torch.load(diffusion_pt, map_location=torch.device(self.device))
         del diffusion_ckpt["class_emb.weight"]
         ckpt.update(diffusion_ckpt)
         self.load_state_dict(ckpt)
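The only functional change besides the removal is passing map_location to torch.load. Without it, torch.load tries to restore tensors onto the device recorded in the checkpoint (typically cuda:0), which can fail when the process starts without a visible GPU; with map_location the tensors are materialized on whatever device the model reports. A minimal, self-contained sketch of the pattern, with a placeholder checkpoint path standing in for diffusion_pt:

import torch

# Choose a device the way a CPU-or-GPU process might; "diffusion.pt" is a placeholder path.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

diffusion_ckpt = torch.load("diffusion.pt", map_location=device)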