VAE is a ScriptModule

#1
by sassoshots - opened

Hello! Cool project. The VAE seems to be a TorchScript module instead of a state dict? None of the functions seem to have been exported. When can we expect the actual state dict?

Figured it out. For anyone running into the same issue trying to use the VAE, here is a solution using stable-audio-tools:

    import torch

    # The published file is a TorchScript archive, so load it with torch.jit.load.
    vae_state_dict = torch.jit.load("vae_model.pt", map_location="cpu").state_dict()
    # Re-save the weights as a plain checkpoint and load them into the pretransform
    # of a stable-audio-tools model.
    torch.save(vae_state_dict, "new_vae.ckpt")
    model.pretransform.load_state_dict(torch.load("new_vae.ckpt"), strict=False)
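
If it helps, here is a minimal round-trip sketch for sanity-checking the loaded weights. It assumes `model` is a stable-audio-tools model built from its config (so `model.pretransform` exposes `encode`/`decode`) and `audio` is a `[batch, channels, samples]` float tensor at the model's sample rate:

    import torch

    # Sketch only: assumes `model` is a stable-audio-tools model whose pretransform
    # now carries the DiffRhythm VAE weights loaded above, and `audio` is a float
    # tensor of shape [batch, channels, samples] at the expected sample rate.
    model.pretransform.eval()
    with torch.no_grad():
        latents = model.pretransform.encode(audio)            # waveform -> latents
        reconstruction = model.pretransform.decode(latents)   # latents -> waveform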

For those curious, I'm also not seeing great reconstruction compared to my in-house VAE. It seems to distort 808s and bass components, and the output sounds thin overall.

Same problem here. I had to preprocess the dataset to a lower volume before encoding to avoid distorted reconstructions, and compared to the Stable Audio VAE checkpoint there is almost no audible difference on the same songs. When generating music from the model on Hugging Face, the quality is way better. No clue what the problem is.

The current version of the VAE is exported using torch.jit.script and can be used in the following way:

    import torch
    from huggingface_hub import hf_hub_download

    vae_ckpt_path = hf_hub_download(repo_id="ASLP-lab/DiffRhythm-vae", filename="vae_model.pt")
    vae = torch.jit.load(vae_ckpt_path, map_location='cpu').to(device)
    wav = vae.decode_export(latent)

The input audio should be normalized to -6 dB. The main difference between our VAE and the Stable Audio version is that our VAE is robust to MP3 compression artifacts.
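
In case it saves someone a step, here is a small normalization sketch to run before encoding. It assumes the -6 dB target refers to peak level; if it is meant as a loudness (LUFS) target, substitute a loudness normalizer instead:

    import torch

    def normalize_to_minus_6db(audio: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
        """Scale a waveform so its peak sits at -6 dBFS (peak target assumed)."""
        target_peak = 10.0 ** (-6.0 / 20.0)        # -6 dB -> linear amplitude ~0.501
        peak = audio.abs().max().clamp(min=eps)    # current peak amplitude
        return audio * (target_peak / peak)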

Did you consider how it might influence latents that are full resolution (either encoded or from a generation)? Did you include a distribution of both lossless and lossy audio in your VAE training, or was it only lossy? That could explain the issues I see with the decoder.
