ZeyuXie committed on
Commit
6a59bc1
1 Parent(s): 518b15d

Update pico_model.py

Browse files
Files changed (1) hide show
  1. pico_model.py +1 -31
pico_model.py CHANGED
@@ -12,36 +12,6 @@ from audioldm.audio.stft import TacotronSTFT
12
  from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
13
  from audioldm.utils import default_audioldm_config, get_metadata
14
 
15
-
16
-
17
- def build_pretrained_models(name):
18
- checkpoint = torch.load(get_metadata()[name]["path"], map_location="cpu")
19
- scale_factor = checkpoint["state_dict"]["scale_factor"].item()
20
-
21
- vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k}
22
-
23
- config = default_audioldm_config(name)
24
- vae_config = config["model"]["params"]["first_stage_config"]["params"]
25
- vae_config["scale_factor"] = scale_factor
26
-
27
- vae = AutoencoderKL(**vae_config)
28
- vae.load_state_dict(vae_state_dict)
29
-
30
- fn_STFT = TacotronSTFT(
31
- config["preprocessing"]["stft"]["filter_length"],
32
- config["preprocessing"]["stft"]["hop_length"],
33
- config["preprocessing"]["stft"]["win_length"],
34
- config["preprocessing"]["mel"]["n_mel_channels"],
35
- config["preprocessing"]["audio"]["sampling_rate"],
36
- config["preprocessing"]["mel"]["mel_fmin"],
37
- config["preprocessing"]["mel"]["mel_fmax"],
38
- )
39
-
40
- vae.eval()
41
- fn_STFT.eval()
42
-
43
- return vae, fn_STFT
44
-
45
  def _init_layer(layer):
46
  """Initialize a Linear or Convolutional layer. """
47
  nn.init.xavier_uniform_(layer.weight)
@@ -260,7 +230,7 @@ class PicoDiffusion(ClapText_Onset_2_Audio_Diffusion):
260
  ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
261
  del_parameter_key = ["text_branch.embeddings.position_ids"]
262
  ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
263
- diffusion_ckpt = torch.load(diffusion_pt)
264
  del diffusion_ckpt["class_emb.weight"]
265
  ckpt.update(diffusion_ckpt)
266
  self.load_state_dict(ckpt)
 
12
  from audioldm.variational_autoencoder.autoencoder import AutoencoderKL
13
  from audioldm.utils import default_audioldm_config, get_metadata
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def _init_layer(layer):
16
  """Initialize a Linear or Convolutional layer. """
17
  nn.init.xavier_uniform_(layer.weight)
 
230
  ckpt = clap_load_state_dict(freeze_text_encoder_ckpt, skip_params=True)
231
  del_parameter_key = ["text_branch.embeddings.position_ids"]
232
  ckpt = {f"freeze_text_encoder.model.{k}":v for k, v in ckpt.items() if k not in del_parameter_key}
233
+ diffusion_ckpt = torch.load(diffusion_pt, map_location=torch.device(self.device))
234
  del diffusion_ckpt["class_emb.weight"]
235
  ckpt.update(diffusion_ckpt)
236
  self.load_state_dict(ckpt)