import os

import yaml
import torch
import torchaudio

from audioldm2.latent_diffusion.models.ddpm import LatentDiffusion
from audioldm2.utils import default_audioldm_config, get_metadata, download_checkpoint
from audioldm2.utilities.audio import read_wav_file

# Directory for cached artifacts; overridable through AUDIOLDM_CACHE_DIR.
CACHE_DIR = os.getenv(
    "AUDIOLDM_CACHE_DIR", os.path.join(os.path.expanduser("~"), ".cache/audioldm2")
)


def seed_everything(seed):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        seed: Integer seed applied to ``random``, ``numpy.random``,
            ``torch`` (CPU and current CUDA device) and ``PYTHONHASHSEED``.
    """
    import random

    import numpy as np

    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    # The cuDNN autotuner may select different convolution algorithms from
    # run to run, which defeats seeding; it must be off for reproducibility.
    torch.backends.cudnn.benchmark = False


def text_to_filename(text):
    """Return *text* with spaces, single quotes and double quotes replaced by ``_``."""
    return text.replace(" ", "_").replace("'", "_").replace('"', "_")


def extract_kaldi_fbank_feature(waveform, sampling_rate, log_mel_spec):
    """Compute a normalized 128-bin Kaldi filterbank for *waveform*.

    The waveform is resampled to 16 kHz when needed and mean-centred. The
    filterbank is then zero-padded or truncated along its first axis to
    ``log_mel_spec.size(0)`` frames and normalized with fixed constants.

    Args:
        waveform: Audio tensor accepted by ``torchaudio.compliance.kaldi.fbank``.
        sampling_rate: Sampling rate of *waveform* in Hz.
        log_mel_spec: Tensor whose first dimension gives the target frame count.

    Returns:
        A dict with the single key ``"ta_kaldi_fbank"`` holding the feature tensor.
    """
    norm_mean = -4.2677393
    norm_std = 4.5689974

    if sampling_rate != 16000:
        waveform_16k = torchaudio.functional.resample(
            waveform, orig_freq=sampling_rate, new_freq=16000
        )
    else:
        waveform_16k = waveform

    waveform_16k = waveform_16k - waveform_16k.mean()
    fbank = torchaudio.compliance.kaldi.fbank(
        waveform_16k,
        htk_compat=True,
        sample_frequency=16000,
        use_energy=False,
        window_type="hanning",
        num_mel_bins=128,
        dither=0.0,
        frame_shift=10,
    )

    target_len = log_mel_spec.size(0)

    # Cut or pad along the frame axis so the output has exactly target_len rows.
    n_frames = fbank.shape[0]
    missing = target_len - n_frames
    if missing > 0:
        # (left, right, top, bottom): append `missing` zero rows at the end.
        fbank = torch.nn.ZeroPad2d((0, 0, 0, missing))(fbank)
    elif missing < 0:
        fbank = fbank[:target_len, :]

    fbank = (fbank - norm_mean) / (norm_std * 2)

    # Original author's shape note: [1024, 128] — TODO confirm against callers.
    return {"ta_kaldi_fbank": fbank}


def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1):
    """Build the input batch dict consumed by the text-to-audio generation call.

    Args:
        text: Prompt; it is repeated ``batchsize`` times.
        waveform: Optional audio data; zeros are used when omitted.
        fbank: Optional log-mel data expandable to ``(batchsize, 1024, 64)``;
            zeros are used when omitted.
        batchsize: Number of items in the batch. Values below 1 are replaced
            by 1 after printing a warning.

    Returns:
        A dict with keys ``text``, ``fname``, ``waveform``, ``stft``,
        ``log_mel_spec`` and ``ta_kaldi_fbank``.
    """
    # Validate before anything is sized from batchsize; the original checked
    # afterwards and never applied the fallback its warning announced.
    if batchsize < 1:
        print("Warning: Batchsize must be at least 1. Batchsize is set to 1.")
        batchsize = 1

    text = [text] * batchsize

    if fbank is None:
        # Not used, here to keep the code format
        fbank = torch.zeros((batchsize, 1024, 64))
    else:
        fbank = torch.FloatTensor(fbank)
        fbank = fbank.expand(batchsize, 1024, 64)
        assert fbank.size(0) == batchsize

    stft = torch.zeros((batchsize, 1024, 512))  # Not used

    if waveform is None:
        waveform = torch.zeros((batchsize, 160000))  # Not used
        ta_kaldi_fbank = torch.zeros((batchsize, 1024, 128))
    else:
        waveform = torch.FloatTensor(waveform)
        waveform = waveform.expand(batchsize, -1)
        assert waveform.size(0) == batchsize
        # NOTE(review): extract_kaldi_fbank_feature returns a dict (not a
        # tensor as in the branch above), and it receives the batched fbank,
        # so its target length is fbank.size(0) == batchsize. Behaviour is
        # left unchanged here — confirm what the model expects.
        ta_kaldi_fbank = extract_kaldi_fbank_feature(waveform, 16000, fbank)

    batch = {
        "text": text,  # list
        "fname": [text_to_filename(t) for t in text],  # list
        "waveform": waveform,
        "stft": stft,
        "log_mel_spec": fbank,
        "ta_kaldi_fbank": ta_kaldi_fbank,
    }

    return batch


def round_up_duration(duration):
    """Return ``(round(duration / 2.5) + 1) * 2.5``.

    Note that this always adds one 2.5-unit step after rounding, so an exact
    multiple such as 10 maps to 12.5.
    """
    return int(round(duration / 2.5) + 1) * 2.5


def split_clap_weight_to_pth(checkpoint):
    """Extract the CLAP weights from *checkpoint* into ``CACHE_DIR/clap.pth``.

    Does nothing if the target file already exists. Otherwise every
    ``state_dict`` key containing the CLAP prefix is copied with that prefix
    replaced by ``"module."`` and the result is saved with ``torch.save``.

    Args:
        checkpoint: Mapping with a ``"state_dict"`` entry.
    """
    clap_path = os.path.join(CACHE_DIR, "clap.pth")
    if os.path.exists(clap_path):
        return

    print("Constructing the weight for the CLAP model.")
    include_keys = "cond_stage_models.0.cond_stage_models.0.model."

    new_state_dict = {
        key.replace(include_keys, "module."): value
        for key, value in checkpoint["state_dict"].items()
        if include_keys in key
    }

    # torch.save does not create missing parent directories.
    os.makedirs(CACHE_DIR, exist_ok=True)
    torch.save({"state_dict": new_state_dict}, clap_path)


def build_model(ckpt_path=None, config=None, model_name="audioldm2-full"):
    """Construct a ``LatentDiffusion`` model and load its checkpoint.

    Args:
        ckpt_path: Checkpoint file path. When ``None`` the path registered
            for *model_name* in ``get_metadata()`` is used.
        config: Optional path (``str``) to a YAML config file. When ``None``
            ``default_audioldm_config(model_name)`` is used.
        model_name: Model identifier used for metadata lookup, download and
            the default config.

    Returns:
        The model in eval mode, moved to ``cuda:0`` when CUDA is available
        and to CPU otherwise.
    """
    print("Loading AudioLDM-2: %s" % model_name)
    if ckpt_path is None:
        ckpt_path = get_metadata()[model_name]["path"]

    # NOTE(review): a missing custom ckpt_path also triggers a download by
    # model_name; whether that populates ckpt_path is not visible here.
    if not os.path.exists(ckpt_path):
        download_checkpoint(model_name)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")

    if config is not None:
        assert isinstance(config, str)
        # NOTE(review): FullLoader is not the safe loader; only pass trusted
        # config files. The context manager closes the handle the original
        # left open.
        with open(config, "r") as config_file:
            config = yaml.load(config_file, Loader=yaml.FullLoader)
    else:
        config = default_audioldm_config(model_name)

    # # Use text as condition instead of using waveform during training
    config["model"]["params"]["device"] = device
    # config["model"]["params"]["cond_stage_key"] = "text"

    # No normalization here
    latent_diffusion = LatentDiffusion(**config["model"]["params"])

    checkpoint = torch.load(ckpt_path, map_location=device)
    latent_diffusion.load_state_dict(checkpoint["state_dict"])

    latent_diffusion.eval()
    latent_diffusion = latent_diffusion.to(device)

    return latent_diffusion


def duration_to_latent_t_size(duration):
    """Return ``int(duration * 25.6)``, used as the model's ``latent_t_size``."""
    return int(duration * 25.6)


def text_to_audio(
    latent_diffusion,
    text,
    seed=42,
    ddim_steps=200,
    duration=10,
    batchsize=1,
    guidance_scale=3.5,
    n_candidate_gen_per_text=3,
    config=None,
):
    """Generate audio for *text* with a loaded latent diffusion model.

    Args:
        latent_diffusion: Model exposing ``generate_batch``; its
            ``latent_t_size`` attribute is overwritten by this call.
        text: Prompt to condition on.
        seed: Seed passed to ``seed_everything`` (converted with ``int``).
        ddim_steps: Forwarded to ``generate_batch``.
        duration: Must equal 10; other values fail the assertion below.
        batchsize: Batch size for the constructed input batch.
        guidance_scale: Forwarded as ``unconditional_guidance_scale``.
        n_candidate_gen_per_text: Forwarded as ``n_gen``.
        config: Unused by this function.

    Returns:
        Whatever ``latent_diffusion.generate_batch`` returns.

    Raises:
        AssertionError: If ``duration`` is not 10.
    """
    assert (
        duration == 10
    ), "Error: Currently we only support 10 seconds of generation. Generating longer files requires some extra coding, which would be a part of the future work."

    seed_everything(int(seed))
    waveform = None

    batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize)

    latent_diffusion.latent_t_size = duration_to_latent_t_size(duration)

    with torch.no_grad():
        waveform = latent_diffusion.generate_batch(
            batch,
            unconditional_guidance_scale=guidance_scale,
            ddim_steps=ddim_steps,
            n_gen=n_candidate_gen_per_text,
            duration=duration,
        )
    return waveform