diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..4308f65163299f6897c26b42c03b7751e80fec9d 100644 --- a/.gitattributes +++ b/.gitattributes @@ -33,3 +33,6 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text *.zip filter=lfs diff=lfs merge=lfs -text *.zst filter=lfs diff=lfs merge=lfs -text *tfevents* filter=lfs diff=lfs merge=lfs -text +data/-26aVYRtEAc_000030.mp4 filter=lfs diff=lfs merge=lfs -text +data/-yoaSondvkw_000071.mp4 filter=lfs diff=lfs merge=lfs -text +data/0Bp8c3PfAAA_000053.mp4 filter=lfs diff=lfs merge=lfs -text diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/audioldm/__init__.py b/audioldm/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2bbf85f01ccc72b6f18e7405d940adf07a26b500 --- /dev/null +++ b/audioldm/__init__.py @@ -0,0 +1,8 @@ +from .ldm import LatentDiffusion +from .utils import seed_everything, save_wave, get_time, get_duration +from .pipeline import * + + + + + diff --git a/audioldm/__main__.py b/audioldm/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..13f8bafa839f512a156dd6380d2cf43c573a970a --- /dev/null +++ b/audioldm/__main__.py @@ -0,0 +1,183 @@ +#!/usr/bin/python3 +import os +from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration +import argparse + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".cache/audioldm")) + +parser = argparse.ArgumentParser() + +parser.add_argument( + "--mode", + type=str, + required=False, + default="generation", + help="generation: text-to-audio generation; transfer: style transfer", + choices=["generation", "transfer"] +) + +parser.add_argument( + "-t", + "--text", + type=str, + required=False, + default="", + help="Text prompt to the model for audio generation", +) + +parser.add_argument( + "-f", + "--file_path", + type=str, + required=False, + default=None, + help="(--mode transfer): Original audio file for style 
transfer; Or (--mode generation): the guidance audio file for generating simialr audio", +) + +parser.add_argument( + "--transfer_strength", + type=float, + required=False, + default=0.5, + help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text", +) + +parser.add_argument( + "-s", + "--save_path", + type=str, + required=False, + help="The path to save model output", + default="./output", +) + +parser.add_argument( + "--model_name", + type=str, + required=False, + help="The checkpoint you gonna use", + default="audioldm-s-full", + choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"] +) + +parser.add_argument( + "-ckpt", + "--ckpt_path", + type=str, + required=False, + help="The path to the pretrained .ckpt model", + default=None, +) + +parser.add_argument( + "-b", + "--batchsize", + type=int, + required=False, + default=1, + help="Generate how many samples at the same time", +) + +parser.add_argument( + "--ddim_steps", + type=int, + required=False, + default=200, + help="The sampling step for DDIM", +) + +parser.add_argument( + "-gs", + "--guidance_scale", + type=float, + required=False, + default=2.5, + help="Guidance scale (Large => better quality and relavancy to text; Small => better diversity)", +) + +parser.add_argument( + "-dur", + "--duration", + type=float, + required=False, + default=10.0, + help="The duration of the samples", +) + +parser.add_argument( + "-n", + "--n_candidate_gen_per_text", + type=int, + required=False, + default=3, + help="Automatic quality control. This number control the number of candidates (e.g., generate three audios and choose the best to show you). 
A Larger value usually lead to better quality with heavier computation", +) + +parser.add_argument( + "--seed", + type=int, + required=False, + default=42, + help="Change this value (any integer number) will lead to a different generation result.", +) + +args = parser.parse_args() + +if(args.ckpt_path is not None): + print("Warning: ckpt_path has no effect after version 0.0.20.") + +assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5" + +mode = args.mode +if(mode == "generation" and args.file_path is not None): + mode = "generation_audio_to_audio" + if(len(args.text) > 0): + print("Warning: You have specified the --file_path. --text will be ignored") + args.text = "" + +save_path = os.path.join(args.save_path, mode) + +if(args.file_path is not None): + save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0])) + +text = args.text +random_seed = args.seed +duration = args.duration +guidance_scale = args.guidance_scale +n_candidate_gen_per_text = args.n_candidate_gen_per_text + +os.makedirs(save_path, exist_ok=True) +audioldm = build_model(model_name=args.model_name) + +if(args.mode == "generation"): + waveform = text_to_audio( + audioldm, + text, + args.file_path, + random_seed, + duration=duration, + guidance_scale=guidance_scale, + ddim_steps=args.ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + batchsize=args.batchsize, + ) + +elif(args.mode == "transfer"): + assert args.file_path is not None + assert os.path.exists(args.file_path), "The original audio file \'%s\' for style transfer does not exist." 
% args.file_path + waveform = style_transfer( + audioldm, + text, + args.file_path, + args.transfer_strength, + random_seed, + duration=duration, + guidance_scale=guidance_scale, + ddim_steps=args.ddim_steps, + batchsize=args.batchsize, + ) + waveform = waveform[:,None,:] + +save_wave(waveform, save_path, name="%s_%s" % (get_time(), text)) diff --git a/audioldm/__pycache__/__init__.cpython-310.pyc b/audioldm/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ac9273f78f11206a8c017b4f96e7a53ff21306fe Binary files /dev/null and b/audioldm/__pycache__/__init__.cpython-310.pyc differ diff --git a/audioldm/__pycache__/ldm.cpython-310.pyc b/audioldm/__pycache__/ldm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..19e3f57d461c628596414c4649178a78956d528c Binary files /dev/null and b/audioldm/__pycache__/ldm.cpython-310.pyc differ diff --git a/audioldm/__pycache__/pipeline.cpython-310.pyc b/audioldm/__pycache__/pipeline.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e77eaf5512d112fd243abf9de49350464ef02c84 Binary files /dev/null and b/audioldm/__pycache__/pipeline.cpython-310.pyc differ diff --git a/audioldm/__pycache__/utils.cpython-310.pyc b/audioldm/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6dad647ef842d2ebcd3db459d9d93e0401051fec Binary files /dev/null and b/audioldm/__pycache__/utils.cpython-310.pyc differ diff --git a/audioldm/audio/__init__.py b/audioldm/audio/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..56902e96f041bc4ba6bfadd7a7742023b9560233 --- /dev/null +++ b/audioldm/audio/__init__.py @@ -0,0 +1,2 @@ +from .tools import wav_to_fbank, read_wav_file +from .stft import TacotronSTFT diff --git a/audioldm/audio/__pycache__/__init__.cpython-310.pyc b/audioldm/audio/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
import numpy as np
import torch
import torch.nn.functional as F
from scipy.signal import get_window


# ---------------------------------------------------------------------------
# audioldm/audio/audio_processing.py
# ---------------------------------------------------------------------------


def _pad_center(data, size):
    """Zero-pad the 1-D array `data` on both sides so it has `size` samples.

    The left side receives ``(size - n) // 2`` zeros and the right side the
    remainder. Implemented with NumPy so the result does not depend on the
    signature of ``librosa.util.pad_center`` (its ``size`` argument became
    keyword-only in librosa 0.10, which broke positional callers).

    Raises
    ------
    ValueError
        If `size` is smaller than the length of `data`.
    """
    n = data.shape[-1]
    lpad = (size - n) // 2
    if lpad < 0:
        raise ValueError(
            f"Target size ({size}) must be at least input size ({n})"
        )
    return np.pad(data, (lpad, size - n - lpad), mode="constant")


def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : int or None
        The length of the window function. `None` means "use `n_fft`".

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    norm : optional
        Forwarded to `librosa.util.normalize` before squaring. With the
        default `None` no normalization is applied and librosa is not needed.

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win = get_window(window, win_length, fftbins=True)
    if norm is not None:
        # Imported lazily: only the (rare) normalized path needs librosa.
        import librosa.util as librosa_util

        win = librosa_util.normalize(win, norm=norm)
    win_sq = _pad_center(win**2, n_fft)

    # Fill the envelope by overlap-adding the squared window at every hop
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    Estimate a waveform from magnitudes by iterating inverse/forward STFT.

    PARAMS
    ------
    magnitudes: spectrogram magnitudes (torch tensor)
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    n_iters: number of transform/inverse refinement rounds
    """
    # Start from uniformly random phases; NumPy's RNG is used so that
    # np.random.seed controls the initialization.
    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    angles = torch.from_numpy(angles.astype(np.float32))
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    for _ in range(n_iters):
        # Keep the target magnitudes, adopt the phase of the current estimate.
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    Apply `normalize_fun` to `x` after clamping it from below.

    PARAMS
    ------
    normalize_fun: compression function applied to the clamped input
    C: compression factor
    clip_val: lower bound applied to `x` before compression
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    Invert log compression: exp(x) / C.

    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C


# ---------------------------------------------------------------------------
# audioldm/audio/stft.py
# (uses window_sumsquare and the dynamic-range helpers defined above)
# ---------------------------------------------------------------------------


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    Short-time Fourier transform implemented as 1-D convolutions with fixed
    (non-trainable) Fourier bases registered as buffers.
    """

    def __init__(self, filter_length, hop_length, win_length, window="hann"):
        """Build the forward and inverse bases.

        Parameters
        ----------
        filter_length : int
            FFT size; must be >= `win_length` when a window is used.
        hop_length : int
            Stride between successive frames, in samples.
        win_length : int
            Length of the analysis window before zero-padding.
        window : str or None
            Window specification for `scipy.signal.get_window`; `None`
            disables windowing.
        """
        super().__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep the non-redundant bins and stack real over imaginary parts.
        cutoff = int(self.filter_length / 2 + 1)
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = _pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        """Return `(magnitude, phase)` of `input_data` (shape `(batch, samples)`).

        Both outputs have shape `(batch, filter_length // 2 + 1, n_frames)`.
        The input is moved to the device of the registered bases.
        """
        device = self.forward_basis.device
        input_data = input_data.to(device)

        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        half = int(self.filter_length / 2)
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (half, half, 0, 0),
            mode="reflect",
        )
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            self.forward_basis,
            stride=self.hop_length,
            padding=0,
        )

        cutoff = half + 1
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        # Phase is computed on detached values, as in the original code.
        phase = torch.atan2(imag_part.detach(), real_part.detach())

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct a waveform of shape `(batch, 1, samples)`.

        `magnitude` and `phase` are moved to the device of the registered
        bases; the window-envelope correction is computed on that same device.
        """
        device = self.forward_basis.device
        magnitude, phase = magnitude.to(device), phase.to(device)

        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            self.inverse_basis,
            stride=self.hop_length,
            padding=0,
        )

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects; skip samples whose envelope is
            # numerically zero to avoid dividing by ~0
            tiny = np.finfo(window_sum.dtype).tiny
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny)[0]
            ).to(device)
            # Must live on the same device as inverse_transform, otherwise
            # the in-place division fails when the module is on GPU.
            window_sum = torch.from_numpy(window_sum).to(device)
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        # Drop the reflect padding that transform() added on both sides.
        half = int(self.filter_length / 2)
        inverse_transform = inverse_transform[:, :, half:]
        inverse_transform = inverse_transform[:, :, :-half]

        return inverse_transform

    def forward(self, input_data):
        """Run transform followed by inverse and return the reconstruction.

        The intermediate magnitude and phase are stored on the instance.
        """
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class TacotronSTFT(torch.nn.Module):
    """STFT front-end that also projects magnitudes onto a mel filterbank."""

    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super().__init__()
        # Imported here so that STFT itself has no librosa dependency.
        from librosa.filters import mel as librosa_mel_fn

        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments: librosa >= 0.10 rejects positional ones here.
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        """Compress `magnitudes` with `normalize_fun` (after clamping)."""
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        """Undo log compression (exp)."""
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: torch.FloatTensor with shape (B, T) in range [-1, 1]
        normalize_fun: compression applied to the mel and linear magnitudes

        RETURNS
        -------
        mel_output: compressed mel spectrogram, shape (B, n_mel_channels, T)
        log_magnitudes: compressed linear-frequency magnitudes
        energy: norm of the magnitudes over the frequency axis
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        energy = torch.norm(magnitudes, dim=1)

        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)

        return mel_output, log_magnitudes, energy
0000000000000000000000000000000000000000..d641a982664b6673822c8528a1929c593f011b11 --- /dev/null +++ b/audioldm/audio/tools.py @@ -0,0 +1,85 @@ +import torch +import numpy as np +import torchaudio + + +def get_mel_from_wav(audio, _stft): + audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1) + audio = torch.autograd.Variable(audio, requires_grad=False) + melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) + melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32) + log_magnitudes_stft = ( + torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32) + ) + energy = torch.squeeze(energy, 0).numpy().astype(np.float32) + return melspec, log_magnitudes_stft, energy + + +def _pad_spec(fbank, target_length=1024): + n_frames = fbank.shape[0] + p = target_length - n_frames + # cut and pad + if p > 0: + m = torch.nn.ZeroPad2d((0, 0, 0, p)) + fbank = m(fbank) + elif p < 0: + fbank = fbank[0:target_length, :] + + if fbank.size(-1) % 2 != 0: + fbank = fbank[..., :-1] + + return fbank + + +def pad_wav(waveform, segment_length): + waveform_length = waveform.shape[-1] + assert waveform_length > 100, "Waveform is too short, %s" % waveform_length + if segment_length is None or waveform_length == segment_length: + return waveform + elif waveform_length > segment_length: + return waveform[:segment_length] + elif waveform_length < segment_length: + temp_wav = np.zeros((1, segment_length)) + temp_wav[:, :waveform_length] = waveform + return temp_wav + +def normalize_wav(waveform): + waveform = waveform - np.mean(waveform) + waveform = waveform / (np.max(np.abs(waveform)) + 1e-8) + return waveform * 0.5 + + +def read_wav_file(filename, segment_length): + # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower + waveform, sr = torchaudio.load(filename) # Faster!!! + waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000) + waveform = waveform.numpy()[0, ...] 
class CLAPAudioEmbeddingClassifierFreev2(nn.Module):
    """Frozen CLAP encoder used to produce conditioning embeddings.

    Depending on ``embed_mode`` ("audio" or "text") the module embeds a batch
    of waveforms or a list of strings with the pretrained CLAP model built by
    ``create_model``.  All CLAP parameters have ``requires_grad`` disabled and
    the model is put in eval mode.

    Args:
        pretrained_path: Checkpoint path / tag forwarded to ``create_model``.
        key: Stored on the instance; not read anywhere in this class.
        sampling_rate: Sampling rate of incoming audio.  ``forward`` asserts
            that it is 16000.
        embed_mode: "audio" or "text"; selects the branch used by ``forward``.
        amodel: Audio encoder name forwarded to ``create_model``.
        unconditional_prob: Per-item probability that ``forward`` replaces the
            embedding with the embedding of the empty string.
        random_mute: If true, ``forward`` zeroes a random span of each
            waveform before embedding (audio mode only).
        max_random_mute_portion: Upper bound of the muted span, as a fraction
            of the number of time steps.
        training_mode: When false, ``forward`` rebuilds the CLAP model if it
            is ever found in training mode.
    """

    def __init__(
        self,
        pretrained_path="",
        key="class",
        sampling_rate=16000,
        embed_mode="audio",
        amodel="HTSAT-tiny",
        unconditional_prob=0.1,
        random_mute=False,
        max_random_mute_portion=0.5,
        training_mode=True,
    ):
        super().__init__()

        self.key = key
        self.device = "cpu"
        self.precision = "fp32"
        self.amodel = amodel  # or 'PANN-14'
        self.tmodel = "roberta"  # the best text encoder in our training
        self.enable_fusion = False  # False if you do not want to use the fusion model
        self.fusion_type = "aff_2d"
        self.pretrained = pretrained_path
        self.embed_mode = embed_mode
        self.embed_mode_orig = embed_mode
        self.sampling_rate = sampling_rate
        self.unconditional_prob = unconditional_prob
        self.random_mute = random_mute
        self.tokenize = RobertaTokenizer.from_pretrained("roberta-base")
        self.max_random_mute_portion = max_random_mute_portion
        self.training_mode = training_mode
        self.model, self.model_cfg = create_model(
            self.amodel,
            self.tmodel,
            self.pretrained,
            precision=self.precision,
            device=self.device,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )
        for p in self.model.parameters():
            p.requires_grad = False

        self.model.eval()

    def get_unconditional_condition(self, batchsize):
        """Return the empty-prompt text embedding repeated ``batchsize`` times."""
        # Two empty prompts are tokenized and only the first row is kept; the
        # tokenizer wrapper squeezes a leading dimension of size 1, so a
        # single-item batch would lose its batch axis.
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]
        return torch.cat([self.unconditional_token.unsqueeze(0)] * batchsize, dim=0)

    def batch_to_list(self, batch):
        """Split ``batch`` along its first dimension into a Python list."""
        return [batch[i] for i in range(batch.size(0))]

    def make_decision(self, probability):
        """Return True with the given probability (one uniform draw from torch)."""
        return float(torch.rand(1)) < probability

    def random_uniform(self, start, end):
        """Draw a float uniformly from ``[start, end)`` using torch's RNG."""
        val = torch.rand(1).item()
        return start + (end - start) * val

    def _random_mute(self, waveform):
        """Zero a random contiguous span of every item of ``waveform`` in place.

        The span length is drawn from ``[0, t_steps * max_random_mute_portion)``
        and the input tensor itself is modified and returned.
        """
        # waveform: [bs, t-steps]
        t_steps = waveform.size(-1)
        for i in range(waveform.size(0)):
            mute_size = int(
                self.random_uniform(0, end=int(t_steps * self.max_random_mute_portion))
            )
            mute_start = int(self.random_uniform(0, t_steps - mute_size))
            waveform[i, mute_start : mute_start + mute_size] = 0
        return waveform

    def cos_similarity(self, waveform, text):
        """Return the cosine similarity between audio and text embeddings.

        Args:
            waveform: Batch of waveforms, embedded through the "audio" branch.
            text: Batch of strings, embedded through the "text" branch.

        Returns:
            ``F.cosine_similarity(audio_emb, text_emb, dim=2)`` with all
            singleton dimensions squeezed out.
        """
        # waveform: [bs, t_steps]
        previous_mode = self.embed_mode
        # Follow the device the CLAP weights actually live on instead of
        # unconditionally calling .cuda(), so CPU-only runs work as well.
        reference = next(self.model.parameters(), None)
        if reference is not None:
            waveform = waveform.to(reference.device)
        try:
            with torch.no_grad():
                self.embed_mode = "audio"
                audio_emb = self(waveform)
                self.embed_mode = "text"
                text_emb = self(text)
                # The previous code appended the two embeddings to this value,
                # turning it into a tuple on which .squeeze() always failed.
                similarity = F.cosine_similarity(audio_emb, text_emb, dim=2)
        finally:
            # Do not leave the conditioner silently switched to text mode.
            self.embed_mode = previous_mode
        return similarity.squeeze()

    def forward(self, batch, key=None):
        """Embed ``batch`` according to ``self.embed_mode``.

        Args:
            batch: Waveform tensor in "audio" mode, list of strings in "text"
                mode.
            key: Unused; kept for interface compatibility.

        Returns:
            The detached embedding with an extra axis inserted at dim 1.  Each
            item is replaced by the empty-prompt embedding with probability
            ``self.unconditional_prob``.

        Raises:
            ValueError: If ``self.embed_mode`` is neither "audio" nor "text".
        """
        # If you want this conditioner to be unconditional, set self.unconditional_prob = 1.0
        # If you want this conditioner to be fully conditional, set self.unconditional_prob = 0.0
        if self.model.training and not self.training_mode:
            print(
                "The pretrained CLAP model should always be in eval mode. Reloading model just in case you change the parameters."
            )
            self.model, self.model_cfg = create_model(
                self.amodel,
                self.tmodel,
                self.pretrained,
                precision=self.precision,
                device="cuda",
                enable_fusion=self.enable_fusion,
                fusion_type=self.fusion_type,
            )
            for p in self.model.parameters():
                p.requires_grad = False
            self.model.eval()

        # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
        if self.embed_mode == "audio":
            with torch.no_grad():
                audio_dict_list = []
                assert (
                    self.sampling_rate == 16000
                ), "We only support 16000 sampling rate"
                if self.random_mute:
                    batch = self._random_mute(batch)
                # batch: [bs, 1, t-samples]
                batch = torchaudio.functional.resample(
                    batch, orig_freq=self.sampling_rate, new_freq=48000
                )
                for waveform in self.batch_to_list(batch):
                    audio_dict = {}
                    audio_dict = get_audio_features(
                        audio_dict,
                        waveform,
                        480000,
                        data_truncating="fusion",
                        data_filling="repeatpad",
                        audio_cfg=self.model_cfg["audio_cfg"],
                    )
                    audio_dict_list.append(audio_dict)
                # [bs, 512]
                embed = self.model.get_audio_embedding(audio_dict_list)
        elif self.embed_mode == "text":
            with torch.no_grad():
                # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
                text_data = self.tokenizer(batch)
                embed = self.model.get_text_embedding(text_data)
        else:
            # Previously this fell through and failed on an unbound `embed`.
            raise ValueError(
                f"Unsupported embed_mode {self.embed_mode!r}; expected 'audio' or 'text'"
            )

        embed = embed.unsqueeze(1)
        self.unconditional_token = self.model.get_text_embedding(
            self.tokenizer(["", ""])
        )[0:1]

        # Classifier-free-guidance style dropout of the condition, per item.
        for i in range(embed.size(0)):
            if self.make_decision(self.unconditional_prob):
                embed[i] = self.unconditional_token

        # [bs, 1, 512]
        return embed.detach()

    def tokenizer(self, text):
        """Tokenize ``text`` with the RoBERTa tokenizer, padded/truncated to 512.

        Returns a dict of tensors; a leading dimension of size 1 is squeezed
        out of every entry.
        """
        result = self.tokenize(
            text,
            padding="max_length",
            truncation=True,
            max_length=512,
            return_tensors="pt",
        )
        return {k: v.squeeze(0) for k, v in result.items()}
"""Scratch helpers that run raw text through pretrained BERT / RoBERTa / BART.

Every encoder keeps its own tokenizer/model pair.  The previous layout bound
all three pairs to the same module globals ``tokenizer`` and ``model``; because
functions look globals up at call time, every helper ended up using the pair
assigned last (BART) regardless of its name.
"""
from transformers import BertTokenizer, BertModel
from transformers import RobertaTokenizer, RobertaModel
from transformers import BartTokenizer, BartModel

text = "Replace me by any text you'd like."

# NOTE: all three checkpoints are still loaded eagerly at import time, exactly
# as before.
_bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
_bert_model = BertModel.from_pretrained("bert-base-uncased")

_roberta_tokenizer = RobertaTokenizer.from_pretrained("roberta-base")
_roberta_model = RobertaModel.from_pretrained("roberta-base")

_bart_tokenizer = BartTokenizer.from_pretrained("facebook/bart-base")
_bart_model = BartModel.from_pretrained("facebook/bart-base")

# Backward compatibility: the module used to finish with ``tokenizer`` and
# ``model`` bound to the BART pair, so keep those public names unchanged.
tokenizer = _bart_tokenizer
model = _bart_model


def bert_embeddings(text):
    """Run ``text`` through ``bert-base-uncased`` and return the model output."""
    encoded_input = _bert_tokenizer(text, return_tensors="pt")
    return _bert_model(**encoded_input)


def Roberta_embeddings(text):
    """Run ``text`` through ``roberta-base`` and return the model output."""
    encoded_input = _roberta_tokenizer(text, return_tensors="pt")
    return _roberta_model(**encoded_input)


def bart_embeddings(text):
    """Run ``text`` through ``facebook/bart-base`` and return the model output."""
    encoded_input = _bart_tokenizer(text, return_tensors="pt")
    return _bart_model(**encoded_input)
def _rescan_model_configs():
    """Rebuild the ``_MODEL_CONFIGS`` registry from ``_MODEL_CONFIG_PATHS``.

    Each registered path is either a single ``.json`` file or a directory
    whose ``*.json`` files are collected.  Hidden files are ignored, and a
    parsed config is registered under the file stem only when it contains the
    keys ``embed_dim``, ``audio_cfg`` and ``text_cfg``.  The registry is
    finally re-created in natural (digit-aware) key order.
    """
    global _MODEL_CONFIGS

    extensions = (".json",)
    required_keys = ("embed_dim", "audio_cfg", "text_cfg")

    candidates = []
    for location in _MODEL_CONFIG_PATHS:
        if location.is_file() and location.suffix in extensions:
            candidates.append(location)
        elif location.is_dir():
            for suffix in extensions:
                candidates.extend(location.glob(f"*{suffix}"))

    for cfg_file in candidates:
        if os.path.basename(cfg_file).startswith("."):
            continue  # Ignore hidden files

        with open(cfg_file, "r") as handle:
            parsed = json.load(handle)
        if all(key in parsed for key in required_keys):
            _MODEL_CONFIGS[cfg_file.stem] = parsed

    ordered_names = sorted(_MODEL_CONFIGS, key=_natural_key)
    _MODEL_CONFIGS = {name: _MODEL_CONFIGS[name] for name in ordered_names}
+ k[12:]] = v + return state_dict + + +def create_model( + amodel_name: str, + tmodel_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + openai_model_cache_dir: str = os.path.expanduser(f"{CACHE_DIR}/clip"), + skip_params=True, + pretrained_audio: str = "", + pretrained_text: str = "", + enable_fusion: bool = False, + fusion_type: str = "None" + # pretrained_image: bool = False, +): + amodel_name = amodel_name.replace( + "/", "-" + ) # for callers using old naming with / in ViT names + pretrained_orig = pretrained + pretrained = pretrained.lower() + if pretrained == "openai": + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." + ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + logging.info(f"Loading pretrained ViT-B-16 text encoder from OpenAI.") + # Hard Code in model name + model_cfg["text_cfg"]["model_type"] = tmodel_name + model = load_openai_model( + "ViT-B-16", + model_cfg, + device=device, + jit=jit, + cache_dir=openai_model_cache_dir, + enable_fusion=enable_fusion, + fusion_type=fusion_type, + ) + # See https://discuss.pytorch.org/t/valueerror-attemting-to-unscale-fp16-gradients/81372 + if precision == "amp" or precision == "fp32": + model = model.float() + else: + if amodel_name in _MODEL_CONFIGS: + logging.info(f"Loading {amodel_name} model config.") + model_cfg = deepcopy(_MODEL_CONFIGS[amodel_name]) + else: + logging.error( + f"Model config for {amodel_name} not found; available models {list_models()}." 
+ ) + raise RuntimeError(f"Model config for {amodel_name} not found.") + + if force_quick_gelu: + # override for use of QuickGELU on non-OpenAI transformer models + model_cfg["quick_gelu"] = True + + # if pretrained_image: + # if 'timm_amodel_name' in model_cfg.get('vision_cfg', {}): + # # pretrained weight loading for timm models set via vision_cfg + # model_cfg['vision_cfg']['timm_model_pretrained'] = True + # else: + # assert False, 'pretrained image towers currently only supported for timm models' + model_cfg["text_cfg"]["model_type"] = tmodel_name + model_cfg["enable_fusion"] = enable_fusion + model_cfg["fusion_type"] = fusion_type + model = CLAP(**model_cfg) + + if pretrained: + checkpoint_path = "" + url = get_pretrained_url(amodel_name, pretrained) + if url: + checkpoint_path = download_pretrained(url, root=openai_model_cache_dir) + elif os.path.exists(pretrained_orig): + checkpoint_path = pretrained_orig + if checkpoint_path: + logging.info( + f"Loading pretrained {amodel_name}-{tmodel_name} weights ({pretrained})." + ) + ckpt = load_state_dict(checkpoint_path, skip_params=True) + model.load_state_dict(ckpt) + param_names = [n for n, p in model.named_parameters()] + # for n in param_names: + # print(n, "\t", "Loaded" if n in ckpt else "Unloaded") + else: + logging.warning( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + raise RuntimeError( + f"Pretrained weights ({pretrained}) not found for model {amodel_name}." + ) + + if pretrained_audio: + if amodel_name.startswith("PANN"): + if "Cnn14_mAP" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["model"] + keys = list(audio_ckpt.keys()) + for key in keys: + if ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." 
+ key] = v + elif os.path.basename(pretrained_audio).startswith( + "PANN" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + elif amodel_name.startswith("HTSAT"): + if "HTSAT_AudioSet_Saved" in pretrained_audio: # official checkpoint + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model") and ( + "spectrogram_extractor" not in key + and "logmel_extractor" not in key + ): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "HTSAT" + ): # checkpoint trained via HTSAT codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + audio_ckpt = audio_ckpt["state_dict"] + keys = list(audio_ckpt.keys()) + for key in keys: + if key.startswith("sed_model"): + v = audio_ckpt.pop(key) + audio_ckpt["audio_branch." + key[10:]] = v + elif os.path.basename(pretrained_audio).startswith( + "finetuned" + ): # checkpoint trained via linear probe codebase + audio_ckpt = torch.load(pretrained_audio, map_location="cpu") + else: + raise ValueError("Unknown audio checkpoint") + else: + raise f"this audio encoder pretrained checkpoint is not support" + + model.load_state_dict(audio_ckpt, strict=False) + logging.info( + f"Loading pretrained {amodel_name} weights ({pretrained_audio})." 
+ ) + param_names = [n for n, p in model.named_parameters()] + for n in param_names: + print(n, "\t", "Loaded" if n in audio_ckpt else "Unloaded") + + model.to(device=device) + if precision == "fp16": + assert device.type != "cpu" + convert_weights_to_fp16(model) + + if jit: + model = torch.jit.script(model) + + return model, model_cfg + + +def create_model_and_transforms( + model_name: str, + pretrained: str = "", + precision: str = "fp32", + device: torch.device = torch.device("cpu"), + jit: bool = False, + force_quick_gelu: bool = False, + # pretrained_image: bool = False, +): + model = create_model( + model_name, + pretrained, + precision, + device, + jit, + force_quick_gelu=force_quick_gelu, + # pretrained_image=pretrained_image + ) + preprocess_train = image_transform(model.visual.image_size, is_train=True) + preprocess_val = image_transform(model.visual.image_size, is_train=False) + return model, preprocess_train, preprocess_val + + +def list_models(): + """enumerate available model architectures based on config files""" + return list(_MODEL_CONFIGS.keys()) + + +def add_model_config(path): + """add model config path or file and update registry""" + if not isinstance(path, Path): + path = Path(path) + _MODEL_CONFIG_PATHS.append(path) + _rescan_model_configs() \ No newline at end of file diff --git a/audioldm/clap/open_clip/feature_fusion.py b/audioldm/clap/open_clip/feature_fusion.py new file mode 100644 index 0000000000000000000000000000000000000000..dbe4e170e05894c12ebdc36ba1dc1de65e441b89 --- /dev/null +++ b/audioldm/clap/open_clip/feature_fusion.py @@ -0,0 +1,192 @@ +""" +Feature Fusion for Varible-Length Data Processing +AFF/iAFF is referred and modified from https://github.com/YimianDai/open-aff/blob/master/aff_pytorch/aff_net/fusion.py +According to the paper: Yimian Dai et al, Attentional Feature Fusion, IEEE Winter Conference on Applications of Computer Vision, WACV 2021 +""" + +import torch +import torch.nn as nn + + +class DAF(nn.Module): + """ + 
class iAFF(nn.Module):
    """Iterative attentional feature fusion (iAFF) of two feature maps.

    Args:
        channels: Number of channels of both inputs.
        r: Channel reduction ratio of the bottleneck inside each attention
            branch (``inter_channels = channels // r``).
        type: ``"1D"`` or ``"2D"``; selects the Conv/BatchNorm/pooling flavour.

    Raises:
        ValueError: If ``type`` is neither ``"1D"`` nor ``"2D"``.
    """

    def __init__(self, channels=64, r=4, type="2D"):
        super(iAFF, self).__init__()
        inter_channels = int(channels // r)

        if type == "1D":
            conv, norm, pool = nn.Conv1d, nn.BatchNorm1d, nn.AdaptiveAvgPool1d
        elif type == "2D":
            conv, norm, pool = nn.Conv2d, nn.BatchNorm2d, nn.AdaptiveAvgPool2d
        else:
            # `raise "<str>"` is itself a TypeError; raise a real exception.
            raise ValueError("the type is not supported")

        def bottleneck():
            # 1x1 conv -> BN -> ReLU -> 1x1 conv -> BN
            return [
                conv(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                norm(inter_channels),
                nn.ReLU(inplace=True),
                conv(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                norm(channels),
            ]

        # Attribute names, creation order and Sequential indices are kept
        # unchanged so existing state_dict keys still match.
        # local attention
        self.local_att = nn.Sequential(*bottleneck())
        # global attention
        self.global_att = nn.Sequential(pool(1), *bottleneck())
        # second local attention
        self.local_att2 = nn.Sequential(*bottleneck())
        # second global attention
        self.global_att2 = nn.Sequential(pool(1), *bottleneck())

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        """Fuse ``x`` and ``residual`` with two rounds of attention weighting."""
        flag = False
        xa = x + residual
        if xa.size(0) == 1:
            # A single-item batch is duplicated before the attention branches
            # and reduced back to one item at the end -- presumably so that
            # BatchNorm sees more than one sample in training mode.
            xa = torch.cat([xa, xa], dim=0)
            flag = True
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        xi = x * wei + residual * (1 - wei)

        xl2 = self.local_att2(xi)
        # NOTE(review): the second pass reuses `global_att`; `global_att2` is
        # constructed but never called. Left as-is because switching would
        # change the outputs of already-trained checkpoints -- confirm intent.
        xg2 = self.global_att(xi)
        xlg2 = xl2 + xg2
        wei2 = self.sigmoid(xlg2)
        xo = x * wei2 + residual * (1 - wei2)
        if flag:
            xo = xo[0].unsqueeze(0)
        return xo
class AFF(nn.Module):
    """Attentional feature fusion (AFF) of two feature maps.

    Args:
        channels: Number of channels of both inputs.
        r: Channel reduction ratio of the attention bottleneck
            (``inter_channels = channels // r``).
        type: ``"1D"`` or ``"2D"``; selects the Conv/BatchNorm/pooling flavour.

    Raises:
        ValueError: If ``type`` is neither ``"1D"`` nor ``"2D"``.
    """

    def __init__(self, channels=64, r=4, type="2D"):
        super(AFF, self).__init__()
        inter_channels = int(channels // r)

        if type == "1D":
            conv, norm, pool = nn.Conv1d, nn.BatchNorm1d, nn.AdaptiveAvgPool1d
        elif type == "2D":
            conv, norm, pool = nn.Conv2d, nn.BatchNorm2d, nn.AdaptiveAvgPool2d
        else:
            # `raise "<str>"` is itself a TypeError; raise a real exception.
            raise ValueError("the type is not supported.")

        def bottleneck():
            # 1x1 conv -> BN -> ReLU -> 1x1 conv -> BN
            return [
                conv(channels, inter_channels, kernel_size=1, stride=1, padding=0),
                norm(inter_channels),
                nn.ReLU(inplace=True),
                conv(inter_channels, channels, kernel_size=1, stride=1, padding=0),
                norm(channels),
            ]

        # Attribute names and Sequential indices are unchanged, so existing
        # state_dict keys still match.
        self.local_att = nn.Sequential(*bottleneck())
        self.global_att = nn.Sequential(pool(1), *bottleneck())

        self.sigmoid = nn.Sigmoid()

    def forward(self, x, residual):
        """Return ``2 * x * w + 2 * residual * (1 - w)`` with learned weights ``w``."""
        flag = False
        xa = x + residual
        if xa.size(0) == 1:
            # A single-item batch is duplicated before the attention branches
            # and reduced back to one item at the end -- presumably so that
            # BatchNorm sees more than one sample in training mode.
            xa = torch.cat([xa, xa], dim=0)
            flag = True
        xl = self.local_att(xa)
        xg = self.global_att(xa)
        xlg = xl + xg
        wei = self.sigmoid(xlg)
        xo = 2 * x * wei + 2 * residual * (1 - wei)
        if flag:
            xo = xo[0].unsqueeze(0)
        return xo
# from PyTorch internals
def _ntuple(n):
    """Return a converter that expands a scalar into an ``n``-tuple.

    The converter returns any iterable argument (tuples, lists, and also
    strings) unchanged; every other value is repeated ``n`` times.
    """

    def parse(x):
        return x if isinstance(x, collections.abc.Iterable) else (x,) * n

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Args:
        drop_prob: Probability of dropping a sample's path.  ``None`` (the
            default) disables dropping.
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # The constructor default is None; passing it on made drop_path
        # compute `1 - None` in training mode and raise a TypeError.
        if self.drop_prob is None:
            return x
        return drop_path(x, self.drop_prob, self.training)
( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + self.mel_conv2d = nn.Conv2d( + in_chans, + embed_dim, + kernel_size=(patch_size[0], patch_size[1] * 3), + stride=(patch_stride[0], patch_stride[1] * 3), + padding=padding, + ) + if self.fusion_type == "daf_2d": + self.fusion_model = DAF() + elif self.fusion_type == "aff_2d": + self.fusion_model = AFF(channels=embed_dim, type="2D") + elif self.fusion_type == "iaff_2d": + self.fusion_model = iAFF(channels=embed_dim, type="2D") + + def forward(self, x, longer_idx=None): + if (self.enable_fusion) and ( + self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"] + ): + global_x = x[:, 0:1, :, :] + + # global processing + B, C, H, W = global_x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." + global_x = self.proj(global_x) + TW = global_x.size(-1) + if len(longer_idx) > 0: + # local processing + local_x = x[longer_idx, 1:, :, :].contiguous() + B, C, H, W = local_x.shape + local_x = local_x.view(B * C, 1, H, W) + local_x = self.mel_conv2d(local_x) + local_x = local_x.view( + B, C, local_x.size(1), local_x.size(2), local_x.size(3) + ) + local_x = local_x.permute((0, 2, 3, 1, 4)).contiguous().flatten(3) + TB, TC, TH, _ = local_x.size() + if local_x.size(-1) < TW: + local_x = torch.cat( + [ + local_x, + torch.zeros( + (TB, TC, TH, TW - local_x.size(-1)), + device=global_x.device, + ), + ], + dim=-1, + ) + else: + local_x = local_x[:, :, :, :TW] + + global_x[longer_idx] = self.fusion_model(global_x[longer_idx], local_x) + x = global_x + else: + B, C, H, W = x.shape + assert ( + H == self.img_size[0] and W == self.img_size[1] + ), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})." 
class Mlp(nn.Module):
    """MLP as used in Vision Transformer, MLP-Mixer and related networks.

    Two linear layers with an activation in between; the same dropout module
    is applied after the activation and after the second linear layer.
    ``hidden_features`` and ``out_features`` fall back to ``in_features`` when
    they are falsy.
    """

    def __init__(
        self,
        in_features,
        hidden_features=None,
        out_features=None,
        act_layer=nn.GELU,
        drop=0.0,
    ):
        super().__init__()
        if not out_features:
            out_features = in_features
        if not hidden_features:
            hidden_features = in_features
        # Registration order (fc1, act, fc2, drop) is part of the module layout.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
    """Initialise ``tensor`` in place with variance ``scale / fan``.

    Args:
        tensor: Tensor to fill; its fans are computed by
            ``_calculate_fan_in_and_fan_out``.
        scale: Numerator of the target variance.
        mode: ``"fan_in"``, ``"fan_out"`` or ``"fan_avg"``; selects the
            denominator.
        distribution: ``"truncated_normal"``, ``"normal"`` or ``"uniform"``.

    Raises:
        ValueError: If ``mode`` or ``distribution`` is not one of the
            supported values.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    if mode == "fan_in":
        denom = fan_in
    elif mode == "fan_out":
        denom = fan_out
    elif mode == "fan_avg":
        denom = (fan_in + fan_out) / 2
    else:
        # Previously an unknown mode surfaced as an unbound `denom`.
        raise ValueError(f"invalid mode {mode}")

    variance = scale / denom

    if distribution == "truncated_normal":
        # constant is stddev of standard normal truncated to (-2, 2)
        trunc_normal_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
    elif distribution == "normal":
        # In-place init of a leaf that requires grad is only legal under no_grad.
        with torch.no_grad():
            tensor.normal_(std=math.sqrt(variance))
    elif distribution == "uniform":
        bound = math.sqrt(3 * variance)
        with torch.no_grad():
            tensor.uniform_(-bound, bound)
    else:
        raise ValueError(f"invalid distribution {distribution}")
def window_partition(x, window_size):
    """Cut a feature map into non-overlapping square windows.

    Args:
        x: (B, H, W, C)
        window_size (int): window size
    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    batch, height, width, channels = x.shape
    tiled = x.view(
        batch,
        height // window_size,
        window_size,
        width // window_size,
        window_size,
        channels,
    )
    # Bring the two window-grid axes together before flattening them with
    # the batch axis.
    regrouped = tiled.permute(0, 1, 3, 2, 4, 5).contiguous()
    return regrouped.view(-1, window_size, window_size, channels)


def window_reverse(windows, window_size, H, W):
    """Reassemble windows produced by ``window_partition`` into a feature map.

    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image
    Returns:
        x: (B, H, W, C)
    """
    windows_per_image = H * W / window_size / window_size
    batch = int(windows.shape[0] / windows_per_image)
    grid = windows.view(
        batch, H // window_size, W // window_size, window_size, window_size, -1
    )
    return grid.permute(0, 1, 3, 2, 4, 5).contiguous().view(batch, H, W, -1)
class WindowAttention(nn.Module):
    r"""Multi-head self-attention inside one window, with a learned relative position bias.

    The same module serves shifted and non-shifted windows; the shifted case
    is handled through the optional additive ``mask`` given to ``forward``.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): Window height and width (Wh, Ww).
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): Add a learnable bias to the q/k/v projection. Default: True
        qk_scale (float | None, optional): Used instead of ``head_dim ** -0.5`` when truthy.
        attn_drop (float, optional): Dropout ratio on the attention weights. Default: 0.0
        proj_drop (float, optional): Dropout ratio on the projected output. Default: 0.0
    """

    def __init__(
        self,
        dim,
        window_size,
        num_heads,
        qkv_bias=True,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
    ):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # (Wh, Ww)
        self.num_heads = num_heads
        self.scale = qk_scale or (dim // num_heads) ** -0.5

        win_h = window_size[0]
        win_w = window_size[1]

        # One learnable bias per (relative offset, head); there are
        # (2*Wh-1) * (2*Ww-1) distinct offsets inside a window.
        num_offsets = (2 * win_h - 1) * (2 * win_w - 1)
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros(num_offsets, num_heads)
        )

        # Lookup table: (query token, key token) -> row of the bias table.
        grid = torch.stack(
            torch.meshgrid([torch.arange(win_h), torch.arange(win_w)])
        )  # 2, Wh, Ww
        flat = torch.flatten(grid, 1)  # 2, Wh*Ww
        offsets = flat[:, :, None] - flat[:, None, :]  # 2, Wh*Ww, Wh*Ww
        offsets = offsets.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        offsets[:, :, 0] += win_h - 1  # row offsets now start at 0
        offsets[:, :, 1] += win_w - 1  # column offsets now start at 0
        offsets[:, :, 0] *= 2 * win_w - 1  # row-major flattening of the 2-D offset
        self.register_buffer("relative_position_index", offsets.sum(-1))

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=0.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """Attend within each window.

        Args:
            x: input features of shape (num_windows*B, N, C).
            mask: additive mask of shape (num_windows, Wh*Ww, Wh*Ww), or None.

        Returns:
            Tuple ``(out, attn)``: ``out`` has the shape of ``x``; ``attn`` is the
            post-dropout attention of shape (num_windows*B, num_heads, N, N).
        """
        batch_windows, tokens, channels = x.shape
        heads = self.num_heads

        qkv = self.qkv(x).reshape(batch_windows, tokens, 3, heads, channels // heads)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        # Explicit indexing (not tuple unpacking of a tensor) keeps TorchScript happy.
        query, key, value = qkv[0], qkv[1], qkv[2]

        scores = (query * self.scale) @ key.transpose(-2, -1)

        area = self.window_size[0] * self.window_size[1]
        bias = self.relative_position_bias_table[self.relative_position_index.view(-1)]
        bias = bias.view(area, area, -1)  # Wh*Ww, Wh*Ww, nH
        bias = bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        scores = scores + bias.unsqueeze(0)

        if mask is not None:
            num_windows = mask.shape[0]
            # Broadcast the per-window mask over the batch and head axes.
            scores = scores.view(
                batch_windows // num_windows, num_windows, heads, tokens, tokens
            ) + mask.unsqueeze(1).unsqueeze(0)
            scores = scores.view(-1, heads, tokens, tokens)

        attn = self.attn_drop(self.softmax(scores))

        out = (attn @ value).transpose(1, 2).reshape(batch_windows, tokens, channels)
        out = self.proj_drop(self.proj(out))
        return out, attn

    def extra_repr(self):
        return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
class _TokenBatchNorm1d(nn.Module):
    """Apply ``nn.BatchNorm1d`` over the channel axis of a (B, L, C) token tensor."""

    def __init__(self, dim):
        super().__init__()
        self.bn = nn.BatchNorm1d(dim)

    def forward(self, x):
        # BatchNorm1d normalises dim 1, so move channels there and back again.
        return self.bn(x.transpose(1, 2)).transpose(1, 2)


# We use the model based on Swin Transformer blocks, therefore the pretrained
# Swin Transformer weights can be reused.
class SwinTransformerBlock(nn.Module):
    r"""Swin Transformer Block: (shifted-)window attention followed by an MLP.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution (H, W) of the token grid.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer used before attention. Default: nn.LayerNorm
        norm_before_mlp (str, optional): Normalization before the MLP, ``"ln"`` (LayerNorm)
            or ``"bn"`` (BatchNorm1d over channels). Default: "ln"

    Raises:
        NotImplementedError: if ``norm_before_mlp`` is neither ``"ln"`` nor ``"bn"``.
    """

    def __init__(
        self,
        dim,
        input_resolution,
        num_heads,
        window_size=7,
        shift_size=0,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        norm_before_mlp="ln",
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        self.norm_before_mlp = norm_before_mlp
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert (
            0 <= self.shift_size < self.window_size
        ), "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim,
            window_size=to_2tuple(self.window_size),
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
        )

        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        if self.norm_before_mlp == "ln":
            self.norm2 = nn.LayerNorm(dim)
        elif self.norm_before_mlp == "bn":
            # A registered module (instead of a lambda that built a fresh
            # BatchNorm1d on every call) so the layer is trained, follows
            # .to()/.eval(), keeps running statistics and can be pickled.
            # NOTE: this adds ``norm2.bn.*`` entries to the state dict in "bn" mode.
            self.norm2 = _TokenBatchNorm1d(dim)
        else:
            raise NotImplementedError
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
        )

        if self.shift_size > 0:
            # calculate attention mask for SW-MSA
            H, W = self.input_resolution
            img_mask = torch.zeros((1, H, W, 1))  # 1 H W 1
            # The same three bands are used along both axes.
            bands = (
                slice(0, -self.window_size),
                slice(-self.window_size, -self.shift_size),
                slice(-self.shift_size, None),
            )
            cnt = 0
            for h in bands:
                for w in bands:
                    # Give each of the 9 regions its own id.
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = window_partition(
                img_mask, self.window_size
            )  # nW, window_size, window_size, 1
            mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
            # Token pairs from different regions get -100, same-region pairs get 0.
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            attn_mask = attn_mask.masked_fill(
                attn_mask != 0, float(-100.0)
            ).masked_fill(attn_mask == 0, float(0.0))
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def forward(self, x):
        """Run the block.

        Args:
            x: tokens of shape (B, L, C). ``L == H * W`` is not asserted here;
                a mismatch surfaces as an error from ``x.view(B, H, W, C)``.

        Returns:
            Tuple ``(x, attn)``: the updated tokens (B, H*W, C) and the attention
            weights returned by the window attention module.
        """
        H, W = self.input_resolution
        B, L, C = x.shape

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(
                x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
            )
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(
            shifted_x, self.window_size
        )  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(
            -1, self.window_size * self.window_size, C
        )  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows, attn = self.attn(
            x_windows, mask=self.attn_mask
        )  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(
                shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
            )
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # residual connections around attention and the FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x, attn

    def extra_repr(self):
        return (
            f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
            f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
        )
class PatchMerging(nn.Module):
    r"""Patch Merging Layer: fold every 2x2 neighbourhood of tokens into one token.

    Args:
        input_resolution (tuple[int]): (H, W) of the incoming token grid.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization applied to the 4*dim
            concatenation. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        # The 4*dim concatenated channels are projected down to 2*dim.
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """Merge neighbouring patches.

        Args:
            x: tensor of shape (B, H*W, C).

        Returns:
            Tensor of shape (B, H/2*W/2, 2*dim).
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        grid = x.view(B, H, W, C)
        # (row offset, column offset) of the four tokens in each 2x2 cell, in
        # the channel order the reduction layer is applied to.
        corners = [
            grid[:, dh::2, dw::2, :] for dh, dw in ((0, 0), (1, 0), (0, 1), (1, 1))
        ]  # each B H/2 W/2 C
        merged = torch.cat(corners, -1).view(B, -1, 4 * C)  # B H/2*W/2 4*C

        return self.reduction(self.norm(merged))

    def extra_repr(self):
        return f"input_resolution={self.input_resolution}, dim={self.dim}"
class BasicLayer(nn.Module):
    """A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | list[float] | tuple[float], optional): Stochastic depth rate,
            either one value for all blocks or one value per block. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
        norm_before_mlp (str, optional): Forwarded to every block. Default: "ln"
    """

    def __init__(
        self,
        dim,
        input_resolution,
        depth,
        num_heads,
        window_size,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        drop_path=0.0,
        norm_layer=nn.LayerNorm,
        downsample=None,
        use_checkpoint=False,
        norm_before_mlp="ln",
    ):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # A sequence gives one rate per block; tuples are accepted as well as lists.
        per_block_rates = isinstance(drop_path, (list, tuple))

        # build blocks: even blocks use regular windows, odd blocks shifted ones
        self.blocks = nn.ModuleList(
            [
                SwinTransformerBlock(
                    dim=dim,
                    input_resolution=input_resolution,
                    num_heads=num_heads,
                    window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias,
                    qk_scale=qk_scale,
                    drop=drop,
                    attn_drop=attn_drop,
                    drop_path=drop_path[i] if per_block_rates else drop_path,
                    norm_layer=norm_layer,
                    norm_before_mlp=norm_before_mlp,
                )
                for i in range(depth)
            ]
        )

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(
                input_resolution, dim=dim, norm_layer=norm_layer
            )
        else:
            self.downsample = None

    def forward(self, x):
        """Run all blocks, then the optional downsample layer.

        Returns:
            Tuple ``(x, attn)``. In eval mode ``attn`` is the mean of the
            attention returned by every block; in training mode it is the
            attention of the last block (``None`` if there are no blocks).
        """
        attns = []
        attn = None
        for blk in self.blocks:
            if self.use_checkpoint:
                # Each block returns (x, attn); unpack both, otherwise the tuple
                # would be fed to the next block. Requires torch >= 1.11 for
                # the explicit ``use_reentrant`` argument.
                x, attn = checkpoint.checkpoint(blk, x, use_reentrant=False)
            else:
                x, attn = blk(x)
            if not self.training:
                attns.append(attn.unsqueeze(0))
        if self.downsample is not None:
            x = self.downsample(x)
        if not self.training and attns:
            attn = torch.cat(attns, dim=0)
            attn = torch.mean(attn, dim=0)
        return x, attn

    def extra_repr(self):
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
# The Core of HTSAT
class HTSAT_Swin_Transformer(nn.Module):
    r"""HTSAT based on the Swin Transformer.

    Args:
        spec_size (int | tuple(int)): Input spectrogram size. Default: 256
        patch_size (int | tuple(int)): Patch size. Default: 4
        patch_stride (int | tuple(int)): Patch stride for the frequency and time axes. Default: (4, 4)
        in_chans (int): Number of input image channels. Default: 1 (mono)
        num_classes (int): Number of classes for classification head. Default: 527
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each HTSAT-Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 8
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
        norm_before_mlp (str): Normalization before each block's MLP. Default: "ln"
        config: Audio configuration object. It must provide ``mel_bins``,
            ``window_size``, ``hop_size``, ``sample_rate``, ``fmin`` and ``fmax``.
        enable_fusion (bool): If True, ``forward`` consumes the ``"mel_fusion"``
            and ``"longer"`` entries of its input instead of ``"waveform"``. Default: False
        fusion_type (str): Fusion variant; the ``*_1d`` variants add a 1-D fusion
            model inside this module. Default: "None"

    Raises:
        ValueError: if ``config`` is None.
    """

    def __init__(
        self,
        spec_size=256,
        patch_size=4,
        patch_stride=(4, 4),
        in_chans=1,
        num_classes=527,
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[4, 8, 16, 32],
        window_size=8,
        mlp_ratio=4.0,
        qkv_bias=True,
        qk_scale=None,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        drop_path_rate=0.1,
        norm_layer=nn.LayerNorm,
        ape=False,
        patch_norm=True,
        use_checkpoint=False,
        norm_before_mlp="ln",
        config=None,
        enable_fusion=False,
        fusion_type="None",
        **kwargs,
    ):
        super(HTSAT_Swin_Transformer, self).__init__()

        if config is None:
            # Every extractor below reads its settings from ``config``.
            raise ValueError("HTSAT_Swin_Transformer requires a config object")

        self.config = config
        self.spec_size = spec_size
        self.patch_stride = patch_stride
        self.patch_size = patch_size
        self.window_size = window_size
        self.embed_dim = embed_dim
        # Copy so instances never share (or mutate) the list defaults.
        self.depths = list(depths)
        self.ape = ape
        self.in_chans = in_chans
        self.num_classes = num_classes
        self.num_heads = list(num_heads)
        self.num_layers = len(self.depths)
        self.num_features = int(self.embed_dim * 2 ** (self.num_layers - 1))

        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.drop_path_rate = drop_path_rate

        self.qkv_bias = qkv_bias
        # Honour the caller's override (it was previously discarded).
        self.qk_scale = qk_scale

        self.patch_norm = patch_norm
        # ``patch_norm`` only controls normalization inside the patch embedding;
        # the transformer layers and the final norm always need ``norm_layer``.
        self.norm_layer = norm_layer
        patch_norm_layer = norm_layer if self.patch_norm else None
        self.norm_before_mlp = norm_before_mlp
        self.mlp_ratio = mlp_ratio

        self.use_checkpoint = use_checkpoint

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # process mel-spec ; used only once
        self.freq_ratio = self.spec_size // self.config.mel_bins
        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None
        self.interpolate_ratio = 32  # Downsampled ratio
        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=config.window_size,
            hop_length=config.hop_size,
            win_length=config.window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )
        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=config.sample_rate,
            n_fft=config.window_size,
            n_mels=config.mel_bins,
            fmin=config.fmin,
            fmax=config.fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )
        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )  # 2 2
        self.bn0 = nn.BatchNorm2d(self.config.mel_bins)

        # split spectrogram into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=self.spec_size,
            patch_size=self.patch_size,
            in_chans=self.in_chans,
            embed_dim=self.embed_dim,
            norm_layer=patch_norm_layer,
            patch_stride=patch_stride,
            enable_fusion=self.enable_fusion,
            fusion_type=self.fusion_type,
        )

        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.grid_size
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(
                torch.zeros(1, num_patches, self.embed_dim)
            )
            trunc_normal_(self.absolute_pos_embed, std=0.02)

        self.pos_drop = nn.Dropout(p=self.drop_rate)

        # stochastic depth decay rule: rates grow linearly over all blocks
        dpr = [
            x.item() for x in torch.linspace(0, self.drop_path_rate, sum(self.depths))
        ]

        # build layers; each stage doubles the channels and halves the resolution
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(
                dim=int(self.embed_dim * 2**i_layer),
                input_resolution=(
                    patches_resolution[0] // (2**i_layer),
                    patches_resolution[1] // (2**i_layer),
                ),
                depth=self.depths[i_layer],
                num_heads=self.num_heads[i_layer],
                window_size=self.window_size,
                mlp_ratio=self.mlp_ratio,
                qkv_bias=self.qkv_bias,
                qk_scale=self.qk_scale,
                drop=self.drop_rate,
                attn_drop=self.attn_drop_rate,
                drop_path=dpr[
                    sum(self.depths[:i_layer]) : sum(self.depths[: i_layer + 1])
                ],
                norm_layer=self.norm_layer,
                downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                use_checkpoint=use_checkpoint,
                norm_before_mlp=self.norm_before_mlp,
            )
            self.layers.append(layer)

        self.norm = self.norm_layer(self.num_features)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.maxpool = nn.AdaptiveMaxPool1d(1)

        SF = (
            self.spec_size
            // (2 ** (len(self.depths) - 1))
            // self.patch_stride[0]
            // self.freq_ratio
        )
        self.tscam_conv = nn.Conv2d(
            in_channels=self.num_features,
            out_channels=self.num_classes,
            kernel_size=(SF, 3),
            padding=(0, 1),
        )
        self.head = nn.Linear(num_classes, num_classes)

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
        ):
            self.mel_conv1d = nn.Sequential(
                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
                nn.BatchNorm1d(64),
            )
            if self.fusion_type == "daf_1d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_1d":
                self.fusion_model = AFF(channels=64, type="1D")
            elif self.fusion_type == "iaff_1d":
                self.fusion_model = iAFF(channels=64, type="1D")

        self.apply(self._init_weights)

    def _init_weights(self, m):
        """Initialise Linear weights (truncated normal) and LayerNorm affine terms."""
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {"absolute_pos_embed"}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {"relative_position_bias_table"}

    def forward_features(self, x, longer_idx=None):
        """Encode an image-shaped spectrogram.

        Args:
            x: tensor whose dim 2 gives the frame count used to rebuild the
                (frequency, time) grid after the transformer layers.
            longer_idx: forwarded to the patch embedding.

        Returns:
            dict with ``"framewise_output"`` and ``"clipwise_output"`` (both
            already passed through a sigmoid), ``"fine_grained_embedding"``
            and ``"embedding"``.
        """
        # A deprecated optimization for using a hierarchical output from different blocks
        frames_num = x.shape[2]
        x = self.patch_embed(x, longer_idx=longer_idx)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)
        for layer in self.layers:
            x, attn = layer(x)
        x = self.norm(x)
        B, N, C = x.shape
        downsample = 2 ** (len(self.depths) - 1)
        SF = frames_num // downsample // self.patch_stride[0]
        ST = frames_num // downsample // self.patch_stride[1]
        x = x.permute(0, 2, 1).contiguous().reshape(B, C, SF, ST)
        B, C, freq_bins, T = x.shape
        # group 2D CNN: undo the frequency folding done by reshape_wav2img
        c_freq_bin = freq_bins // self.freq_ratio
        x = x.reshape(B, C, freq_bins // c_freq_bin, c_freq_bin, T)
        x = x.permute(0, 1, 3, 2, 4).contiguous().reshape(B, C, c_freq_bin, -1)
        # get latent_output
        fine_grained_latent_output = torch.mean(x, dim=2)
        fine_grained_latent_output = interpolate(
            fine_grained_latent_output.permute(0, 2, 1).contiguous(),
            8 * self.patch_stride[1],
        )

        latent_output = self.avgpool(torch.flatten(x, 2))
        latent_output = torch.flatten(latent_output, 1)

        x = self.tscam_conv(x)
        x = torch.flatten(x, 2)  # B, C, T

        fpx = interpolate(
            torch.sigmoid(x).permute(0, 2, 1).contiguous(), 8 * self.patch_stride[1]
        )

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        output_dict = {
            "framewise_output": fpx,  # already sigmoided
            "clipwise_output": torch.sigmoid(x),
            "fine_grained_embedding": fine_grained_latent_output,
            "embedding": latent_output,
        }

        return output_dict

    def crop_wav(self, x, crop_size, spe_pos=None):
        """Crop ``crop_size`` frames along dim 2 of channel 0 for every sample.

        Args:
            x: tensor of shape (B, C, T, F).
            crop_size (int): number of frames to keep.
            spe_pos (int | None): fixed start frame; a random start is drawn
                per sample when None.

        Returns:
            float32 tensor of shape (B, C, crop_size, F) on ``x``'s device; only
            channel 0 is filled, the remaining channels stay zero.
        """
        time_steps = x.shape[2]
        tx = torch.zeros(
            x.shape[0], x.shape[1], crop_size, x.shape[3], device=x.device
        )
        for i in range(len(x)):
            if spe_pos is None:
                # randint is inclusive on both ends, so the last valid start is
                # time_steps - crop_size (also valid when the two are equal).
                crop_pos = random.randint(0, time_steps - crop_size)
            else:
                crop_pos = spe_pos
            tx[i][0] = x[i, 0, crop_pos : crop_pos + crop_size, :]
        return tx

    # Reshape the waveform to an image size, if you want to use the pretrained swin transformer model
    def reshape_wav2img(self, x):
        """Fold a (B, C, T, F) spectrogram into a (B, C, spec_size, spec_size) image.

        Inputs shorter than the target along time or frequency are first
        upsampled with bicubic interpolation; the time axis is then cut into
        ``freq_ratio`` chunks that are stacked along the frequency axis.
        """
        _, _, T, F = x.shape
        target_T = int(self.spec_size * self.freq_ratio)
        target_F = self.spec_size // self.freq_ratio
        assert (
            T <= target_T and F <= target_F
        ), "the wav size should less than or equal to the swin input size"
        # to avoid bicubic zero error
        if T < target_T:
            x = nn.functional.interpolate(
                x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
            )
        if F < target_F:
            x = nn.functional.interpolate(
                x, (x.shape[2], target_F), mode="bicubic", align_corners=True
            )
        x = x.permute(0, 1, 3, 2).contiguous()  # B C F T
        x = x.reshape(
            x.shape[0],
            x.shape[1],
            x.shape[2],
            self.freq_ratio,
            x.shape[3] // self.freq_ratio,
        )
        x = x.permute(0, 1, 3, 2, 4).contiguous()
        x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3], x.shape[4])
        return x

    # Repeat the waveform to an image size, if you want to use the pretrained swin transformer model
    def repeat_wat2img(self, x, cur_pos):
        """Take ``spec_size`` frames starting at ``cur_pos`` and tile them 4x along frequency."""
        _, _, T, F = x.shape
        target_T = int(self.spec_size * self.freq_ratio)
        target_F = self.spec_size // self.freq_ratio
        assert (
            T <= target_T and F <= target_F
        ), "the wav size should less than or equal to the swin input size"
        # to avoid bicubic zero error
        if T < target_T:
            x = nn.functional.interpolate(
                x, (target_T, x.shape[3]), mode="bicubic", align_corners=True
            )
        if F < target_F:
            x = nn.functional.interpolate(
                x, (x.shape[2], target_F), mode="bicubic", align_corners=True
            )
        x = x.permute(0, 1, 3, 2).contiguous()  # B C F T
        x = x[:, :, :, cur_pos : cur_pos + self.spec_size]
        x = x.repeat(repeats=(1, 1, 4, 1))
        return x

    def forward(self, x, mixup_lambda=None, infer_mode=False, device=None):
        """Encode a batch.

        Args:
            x: mapping of tensors. Without fusion it must contain
                ``"waveform"``; with fusion it must contain ``"mel_fusion"``
                and the boolean ``"longer"`` flags. When fusion is enabled and
                no flag is set, one randomly chosen flag is set in place.
            mixup_lambda: mixup coefficients, applied in training mode only.
            infer_mode: accepted for interface compatibility; not used here.
            device: target device for the inputs (``None`` leaves them in place).

        Returns:
            The dict produced by ``forward_features``.
        """
        if self.enable_fusion and x["longer"].sum() == 0:
            # if no audio is longer than 10s, then randomly select one audio to be longer
            x["longer"][torch.randint(0, x["longer"].shape[0], (1,))] = True

        if not self.enable_fusion:
            x = x["waveform"].to(device=device, non_blocking=True)
            x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins)
            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)
            # bn0 normalises over mel bins, so move them to the channel axis and back
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
            if self.training:
                x = self.spec_augmenter(x)

            if self.training and mixup_lambda is not None:
                x = do_mixup(x, mixup_lambda)

            x = self.reshape_wav2img(x)
            output_dict = self.forward_features(x)
        else:
            longer_list = x["longer"].to(device=device, non_blocking=True)
            x = x["mel_fusion"].to(device=device, non_blocking=True)
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
            longer_list_idx = torch.where(longer_list)[0]
            if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
                new_x = x[:, 0:1, :, :].clone().contiguous()
                if len(longer_list_idx) > 0:
                    # local processing
                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
                    FB, FC, FT, FF = fusion_x_local.size()
                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
                    fusion_x_local = torch.permute(
                        fusion_x_local, (0, 2, 1)
                    ).contiguous()
                    fusion_x_local = self.mel_conv1d(fusion_x_local)
                    fusion_x_local = fusion_x_local.view(
                        FB, FC, FF, fusion_x_local.size(-1)
                    )
                    fusion_x_local = (
                        torch.permute(fusion_x_local, (0, 2, 1, 3))
                        .contiguous()
                        .flatten(2)
                    )
                    if fusion_x_local.size(-1) < FT:
                        # Pad on the tensor's own device/dtype; ``device`` may be
                        # None, which would allocate the padding on the CPU.
                        fusion_x_local = torch.cat(
                            [
                                fusion_x_local,
                                fusion_x_local.new_zeros(
                                    (FB, FF, FT - fusion_x_local.size(-1))
                                ),
                            ],
                            dim=-1,
                        )
                    else:
                        fusion_x_local = fusion_x_local[:, :, :FT]
                    # 1D fusion
                    new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
                    new_x[longer_list_idx] = self.fusion_model(
                        new_x[longer_list_idx], fusion_x_local
                    )
                    x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
                else:
                    x = new_x

            # The 2-D fusion types ("daf_2d", "aff_2d", "iaff_2d", "channel_map")
            # leave x unchanged here.

            if self.training:
                x = self.spec_augmenter(x)
            if self.training and mixup_lambda is not None:
                x = do_mixup(x, mixup_lambda)

            x = self.reshape_wav2img(x)
            output_dict = self.forward_features(x, longer_idx=longer_list_idx)

        # The dataloader is expected to deliver inputs no longer than the fixed
        # model length, so only input_T <= fixed_T is handled here.
        return output_dict
def create_htsat_model(audio_cfg, enable_fusion=False, fusion_type="None"):
    """Build the HTS-AT variant selected by ``audio_cfg.model_name``.

    Args:
        audio_cfg: configuration object; ``model_name`` must be ``"tiny"``,
            ``"base"`` or ``"large"`` and ``class_num`` gives the number of
            classes. The object is also passed to the model as ``config``.
        enable_fusion (bool): forwarded to the model.
        fusion_type (str): forwarded to the model.

    Returns:
        The constructed ``HTSAT_Swin_Transformer``.

    Raises:
        RuntimeError: if the model name is unknown or construction fails; the
            underlying error is attached as ``__cause__``.
    """
    # Only the embedding width and per-stage depths differ between variants.
    variants = {
        "tiny": (96, [2, 2, 6, 2]),
        "base": (128, [2, 2, 12, 2]),
        "large": (256, [2, 2, 12, 2]),
    }
    model_name = getattr(audio_cfg, "model_name", None)
    try:
        if model_name not in variants:
            raise ValueError("model name for HTS-AT is wrong!")
        embed_dim, depths = variants[model_name]
        return HTSAT_Swin_Transformer(
            spec_size=256,
            patch_size=4,
            patch_stride=(4, 4),
            num_classes=audio_cfg.class_num,
            embed_dim=embed_dim,
            depths=depths,
            num_heads=[4, 8, 16, 32],
            window_size=8,
            config=audio_cfg,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )
    except Exception as err:
        # ``Exception`` (not a bare except) lets KeyboardInterrupt/SystemExit
        # propagate; chaining keeps the real failure visible to the caller.
        raise RuntimeError(
            f"Import Model for {model_name} not found, or the audio cfg parameters are not enough."
        ) from err
class LinearProbe(nn.Module):
    """Linear (or MLP) classification head on top of a CLAP audio encoder."""

    def __init__(self, model, mlp, freeze, in_ch, out_ch, act=None):
        """
        Args:
            model: nn.Module, the CLAP model; its ``text_branch`` is set to None to save memory
            mlp: bool, if True, then use the MLP layer as the linear probe module
            freeze: bool, if True, then freeze all the CLAP model's layers when training the linear probe
            in_ch: int, the output channel from CLAP model (currently overridden with 512, see below)
            out_ch: int, the output channel from linear probe (class_num)
            act: str | None, name of the activation applied to the probe output:
                None or "None" (no activation), "relu", "elu", "prelu", "softmax" or "sigmoid"

        Raises:
            ValueError: if ``act`` is not one of the names above.
        """
        super().__init__()
        # NOTE(review): the caller-supplied in_ch is ignored and 512 is used, as
        # in the original code; presumably the CLAP embedding width - confirm.
        in_ch = 512
        self.clap_model = model
        self.clap_model.text_branch = None  # to save memory
        self.freeze = freeze
        if mlp:
            self.lp_layer = MLPLayers(units=[in_ch, in_ch * 2, out_ch])
        else:
            self.lp_layer = nn.Linear(in_ch, out_ch)

        if self.freeze:
            for param in self.clap_model.parameters():
                param.requires_grad = False

        # Always define self.act: the default act=None used to leave it unset,
        # which made forward() fail with AttributeError.
        if act is None or act == "None":
            self.act = None
        elif act == "relu":
            self.act = nn.ReLU()
        elif act == "elu":
            self.act = nn.ELU()
        elif act == "prelu":
            # The activation runs on the probe output, so it needs one
            # parameter per output channel.
            self.act = nn.PReLU(num_parameters=out_ch)
        elif act == "softmax":
            self.act = nn.Softmax(dim=-1)
        elif act == "sigmoid":
            self.act = nn.Sigmoid()
        else:
            raise ValueError(f"Unsupported activation for LinearProbe: {act!r}")

    def forward(self, x, mix_lambda=None, device=None):
        """
        Args:
            x: waveform, torch.tensor [batch, t_samples] / batch of mel_spec and longer list
            mix_lambda: torch.tensor [batch], the mixup lambda
            device: forwarded to the audio branch
        Returns:
            class_prob: torch.tensor [batch, class_num]

        """
        # keep the frozen encoder in eval mode so its batchnorm statistics are not updated
        if self.freeze:
            self.clap_model.eval()

        x = self.clap_model.audio_projection(
            self.clap_model.audio_branch(x, mixup_lambda=mix_lambda, device=device)[
                "embedding"
            ]
        )
        out = self.lp_layer(x)
        if self.act is not None:
            out = self.act(out)
        return out
def gather_features(
    audio_features,
    text_features,
    audio_features_mlp=None,
    text_features_mlp=None,
    local_loss=False,
    gather_with_grad=False,
    rank=0,
    world_size=1,
    use_horovod=False,
    mlp_loss=False,
):
    """Gather audio/text features from all workers and concatenate them on dim 0.

    Args:
        audio_features: this worker's audio features.
        text_features: this worker's text features.
        audio_features_mlp: this worker's MLP audio features; only read when
            ``mlp_loss`` is True.
        text_features_mlp: this worker's MLP text features; only read when
            ``mlp_loss`` is True.
        local_loss: when False (and gradients are not gathered), this worker's
            slot in the gathered result is replaced by its own input tensor.
        gather_with_grad: use the gathering call without ``torch.no_grad`` /
            ``torch.distributed.nn.all_gather`` instead of ``dist.all_gather``.
        rank: index of this worker's slot in the gathered result.
        world_size: number of workers taking part in the gather.
        use_horovod: gather through horovod instead of ``torch.distributed``.
        mlp_loss: also gather the two ``*_mlp`` tensors.

    Returns:
        ``(all_audio_features, all_text_features)`` or, when ``mlp_loss`` is
        True, ``(all_audio_features, all_text_features,
        all_audio_features_mlp, all_text_features_mlp)``.

    Raises:
        AssertionError: if ``use_horovod`` is True but horovod is not installed.
    """
    # NOTE: every worker must issue the gather calls below in the same order.
    if use_horovod:
        assert hvd is not None, "Please install horovod"
        if gather_with_grad:
            all_audio_features = hvd.allgather(audio_features)
            all_text_features = hvd.allgather(text_features)
            if mlp_loss:
                all_audio_features_mlp = hvd.allgather(audio_features_mlp)
                all_text_features_mlp = hvd.allgather(text_features_mlp)
        else:
            with torch.no_grad():
                all_audio_features = hvd.allgather(audio_features)
                all_text_features = hvd.allgather(text_features)
                if mlp_loss:
                    all_audio_features_mlp = hvd.allgather(audio_features_mlp)
                    all_text_features_mlp = hvd.allgather(text_features_mlp)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient:
                # split the gathered tensor per worker and swap this worker's
                # chunk for its original input tensor
                gathered_audio_features = list(
                    all_audio_features.chunk(world_size, dim=0)
                )
                gathered_text_features = list(
                    all_text_features.chunk(world_size, dim=0)
                )
                gathered_audio_features[rank] = audio_features
                gathered_text_features[rank] = text_features
                all_audio_features = torch.cat(gathered_audio_features, dim=0)
                all_text_features = torch.cat(gathered_text_features, dim=0)
                if mlp_loss:
                    gathered_audio_features_mlp = list(
                        all_audio_features_mlp.chunk(world_size, dim=0)
                    )
                    gathered_text_features_mlp = list(
                        all_text_features_mlp.chunk(world_size, dim=0)
                    )
                    gathered_audio_features_mlp[rank] = audio_features_mlp
                    gathered_text_features_mlp[rank] = text_features_mlp
                    all_audio_features_mlp = torch.cat(
                        gathered_audio_features_mlp, dim=0
                    )
                    all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
    else:
        # We gather tensors from all gpus
        if gather_with_grad:
            all_audio_features = torch.cat(
                torch.distributed.nn.all_gather(audio_features), dim=0
            )
            all_text_features = torch.cat(
                torch.distributed.nn.all_gather(text_features), dim=0
            )
            if mlp_loss:
                all_audio_features_mlp = torch.cat(
                    torch.distributed.nn.all_gather(audio_features_mlp), dim=0
                )
                all_text_features_mlp = torch.cat(
                    torch.distributed.nn.all_gather(text_features_mlp), dim=0
                )
        else:
            # dist.all_gather fills pre-allocated per-worker buffers
            gathered_audio_features = [
                torch.zeros_like(audio_features) for _ in range(world_size)
            ]
            gathered_text_features = [
                torch.zeros_like(text_features) for _ in range(world_size)
            ]
            dist.all_gather(gathered_audio_features, audio_features)
            dist.all_gather(gathered_text_features, text_features)
            if mlp_loss:
                gathered_audio_features_mlp = [
                    torch.zeros_like(audio_features_mlp) for _ in range(world_size)
                ]
                gathered_text_features_mlp = [
                    torch.zeros_like(text_features_mlp) for _ in range(world_size)
                ]
                dist.all_gather(gathered_audio_features_mlp, audio_features_mlp)
                dist.all_gather(gathered_text_features_mlp, text_features_mlp)
            if not local_loss:
                # ensure grads for local rank when all_* features don't have a gradient
                gathered_audio_features[rank] = audio_features
                gathered_text_features[rank] = text_features
                if mlp_loss:
                    gathered_audio_features_mlp[rank] = audio_features_mlp
                    gathered_text_features_mlp[rank] = text_features_mlp

            all_audio_features = torch.cat(gathered_audio_features, dim=0)
            all_text_features = torch.cat(gathered_text_features, dim=0)
            if mlp_loss:
                all_audio_features_mlp = torch.cat(gathered_audio_features_mlp, dim=0)
                all_text_features_mlp = torch.cat(gathered_text_features_mlp, dim=0)
    if mlp_loss:
        return (
            all_audio_features,
            all_text_features,
            all_audio_features_mlp,
            all_text_features_mlp,
        )
    else:
        return all_audio_features, all_text_features
class ClipLoss(nn.Module):
    """Contrastive cross-entropy loss between audio and text feature batches.

    Similarity logits are built as scaled matrix products of the audio and
    text features, and each row is classified against its matching index
    (``torch.arange`` labels).  With ``mlp_loss=True`` four logit matrices are
    built by pairing the plain features of one modality with the MLP features
    of the other, and the four losses are averaged.

    Parameters
    ----------
    local_loss: bool
        When ``world_size > 1``, compare only the local batch against the
        gathered features instead of gathered against gathered.
    gather_with_grad: bool
        Forwarded to ``gather_features``.
    cache_labels: bool
        Keep the label tensor per device and reuse it while the number of
        logits is unchanged.
    rank: int
        Rank of this process; offsets the labels when ``local_loss`` is used.
    world_size: int
        Number of processes; features are gathered only when it is > 1.
    use_horovod: bool
        Forwarded to ``gather_features``.
    mlp_loss: bool
        Enables the four-way loss using the MLP-transformed features.
    weight_loss_kappa: float
        When non-zero, every cross-entropy term receives per-class weights
        ``exp(sum(F @ F.T, dim=1) / (kappa * len(F)))``.
    """

    def __init__(
        self,
        local_loss=False,
        gather_with_grad=False,
        cache_labels=False,
        rank=0,
        world_size=1,
        use_horovod=False,
        mlp_loss=False,
        weight_loss_kappa=0,
    ):
        super().__init__()
        self.local_loss = local_loss
        self.gather_with_grad = gather_with_grad
        self.cache_labels = cache_labels
        self.rank = rank
        self.world_size = world_size
        self.use_horovod = use_horovod
        self.mlp_loss = mlp_loss
        self.weighted_loss = bool(weight_loss_kappa != 0)
        self.weight_loss_kappa = weight_loss_kappa
        # cache state
        self.prev_num_logits = 0
        self.labels = {}

    def _target_labels(self, num_logits, device):
        """Return the ground-truth indices, reusing the cached tensor if valid."""
        if self.prev_num_logits != num_logits or device not in self.labels:
            labels = torch.arange(num_logits, device=device, dtype=torch.long)
            if self.world_size > 1 and self.local_loss:
                # Local rows are matched against the gathered columns, so the
                # positives of this rank start at ``num_logits * rank``.
                labels = labels + num_logits * self.rank
            if self.cache_labels:
                self.labels[device] = labels
                self.prev_num_logits = num_logits
        else:
            labels = self.labels[device]
        return labels

    def _class_weights(self, features):
        """Return detached per-class weights derived from self-similarity."""
        similarity = (features @ features.T).detach()
        return torch.exp(
            torch.sum(similarity, dim=1) / (self.weight_loss_kappa * len(features))
        ).detach()

    def forward(
        self,
        audio_features,
        text_features,
        logit_scale_a,
        logit_scale_t=None,
        audio_features_mlp=None,
        text_features_mlp=None,
    ):
        """Compute the averaged contrastive loss.

        Returns a scalar tensor.  Raises ``ValueError`` when ``mlp_loss`` is
        enabled but ``logit_scale_t`` or one of the MLP feature tensors is
        missing.
        """
        device = audio_features.device
        distributed = self.world_size > 1
        # Gathered-vs-gathered logits are symmetric, so one matrix and its
        # transpose suffice; every other case uses explicit products.
        use_transpose = distributed and not self.local_loss

        if self.mlp_loss:
            if (
                logit_scale_t is None
                or audio_features_mlp is None
                or text_features_mlp is None
            ):
                raise ValueError(
                    "mlp_loss=True requires logit_scale_t, audio_features_mlp "
                    "and text_features_mlp"
                )
            if distributed:
                (
                    all_audio_features,
                    all_text_features,
                    all_audio_features_mlp,
                    all_text_features_mlp,
                ) = gather_features(
                    audio_features=audio_features,
                    text_features=text_features,
                    audio_features_mlp=audio_features_mlp,
                    text_features_mlp=text_features_mlp,
                    local_loss=self.local_loss,
                    gather_with_grad=self.gather_with_grad,
                    rank=self.rank,
                    world_size=self.world_size,
                    use_horovod=self.use_horovod,
                    mlp_loss=self.mlp_loss,
                )
            else:
                all_audio_features = audio_features
                all_text_features = text_features
                all_audio_features_mlp = audio_features_mlp
                all_text_features_mlp = text_features_mlp

            if use_transpose:
                a_logits_per_audio = (
                    logit_scale_a * all_audio_features @ all_text_features_mlp.T
                )
                a_logits_per_text = a_logits_per_audio.T
                t_logits_per_audio = (
                    logit_scale_t * all_audio_features_mlp @ all_text_features.T
                )
                t_logits_per_text = t_logits_per_audio.T
            else:
                a_logits_per_audio = (
                    logit_scale_a * audio_features @ all_text_features_mlp.T
                )
                a_logits_per_text = (
                    logit_scale_a * text_features_mlp @ all_audio_features.T
                )
                t_logits_per_audio = (
                    logit_scale_t * audio_features_mlp @ all_text_features.T
                )
                t_logits_per_text = (
                    logit_scale_t * text_features @ all_audio_features_mlp.T
                )

            labels = self._target_labels(a_logits_per_audio.shape[0], device)

            audio_weight = text_weight = None
            if self.weighted_loss:
                # The weights are per class, so they must span every column of
                # the logits: use the gathered features (identical to the local
                # ones when world_size == 1).
                audio_weight = self._class_weights(all_audio_features)
                text_weight = self._class_weights(all_text_features)
            # NOTE(review): both a_* terms use audio_weight and both t_* terms
            # use text_weight, unlike the non-MLP branch which weights each
            # term by the modality of its columns - confirm this is intended.
            total_loss = (
                F.cross_entropy(a_logits_per_audio, labels, weight=audio_weight)
                + F.cross_entropy(a_logits_per_text, labels, weight=audio_weight)
                + F.cross_entropy(t_logits_per_audio, labels, weight=text_weight)
                + F.cross_entropy(t_logits_per_text, labels, weight=text_weight)
            ) / 4
            return total_loss

        if distributed:
            all_audio_features, all_text_features = gather_features(
                audio_features=audio_features,
                text_features=text_features,
                local_loss=self.local_loss,
                gather_with_grad=self.gather_with_grad,
                rank=self.rank,
                world_size=self.world_size,
                use_horovod=self.use_horovod,
                mlp_loss=self.mlp_loss,
            )
        else:
            # Without gathering, the "all" features are simply the local ones;
            # the weighted loss below needs them to be defined.
            all_audio_features = audio_features
            all_text_features = text_features

        if use_transpose:
            logits_per_audio = logit_scale_a * all_audio_features @ all_text_features.T
            logits_per_text = logits_per_audio.T
        else:
            logits_per_audio = logit_scale_a * audio_features @ all_text_features.T
            logits_per_text = logit_scale_a * text_features @ all_audio_features.T

        labels = self._target_labels(logits_per_audio.shape[0], device)

        audio_weight = text_weight = None
        if self.weighted_loss:
            audio_weight = self._class_weights(all_audio_features)
            text_weight = self._class_weights(all_text_features)
        total_loss = (
            F.cross_entropy(logits_per_audio, labels, weight=text_weight)
            + F.cross_entropy(logits_per_text, labels, weight=audio_weight)
        ) / 2
        return total_loss
class MLPLayers(nn.Module):
    """Stack of linear layers with a non-linearity and dropout in between.

    For consecutive sizes ``units[i] -> units[i + 1]`` a ``Linear``, the
    ``nonlin`` module and a ``Dropout`` are appended; the trailing
    non-linearity and dropout are then removed so the stack ends on a
    ``Linear`` layer.

    Parameters
    ----------
    units: sequence of int
        Layer widths; ``len(units) - 1`` linear layers are created.
    nonlin: nn.Module
        Activation module inserted after every linear layer except the last.
        The same instance is reused at every position.
    dropout: float
        Dropout probability used after every activation.
    """

    # ``units`` defaults to an immutable tuple instead of a shared list.
    def __init__(self, units=(512, 512, 512), nonlin=nn.ReLU(), dropout=0.1):
        super().__init__()
        self.nonlin = nonlin
        self.dropout = dropout

        layers = []
        for in_dim, out_dim in zip(units[:-1], units[1:]):
            layers.extend(
                (nn.Linear(in_dim, out_dim), self.nonlin, nn.Dropout(self.dropout))
            )
        # Drop the final activation + dropout so the output is a raw projection.
        self.sequential = nn.Sequential(*layers[:-2])

    def forward(self, X):
        """Apply the layer stack to ``X`` and return the result."""
        return self.sequential(X)
class AttentionPool2d(nn.Module):
    """Pool a 2-D feature map with one multi-head attention step.

    The spatial positions are flattened into a token sequence, their mean is
    prepended as an extra token, a learned positional embedding is added and
    the attention output at the prepended token is returned.
    """

    def __init__(
        self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: int = None
    ):
        super().__init__()
        num_positions = spacial_dim**2 + 1  # one slot per position + the mean token
        init_scale = embed_dim**0.5
        self.positional_embedding = nn.Parameter(
            torch.randn(num_positions, embed_dim) / init_scale
        )
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        # Falls back to embed_dim when no output_dim is given.
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        """Return one pooled vector per batch item for an NCHW input."""
        batch, channels = x.shape[0], x.shape[1]
        positions = x.shape[2] * x.shape[3]
        # NCHW -> (HW)NC
        tokens = x.reshape(batch, channels, positions).permute(2, 0, 1)
        mean_token = tokens.mean(dim=0, keepdim=True)
        tokens = torch.cat([mean_token, tokens], dim=0)  # (HW+1)NC
        tokens = tokens + self.positional_embedding[:, None, :].to(tokens.dtype)

        qkv_bias = torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias])
        pooled, _ = F.multi_head_attention_forward(
            query=tokens,
            key=tokens,
            value=tokens,
            embed_dim_to_check=tokens.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=qkv_bias,
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )

        # Index 0 is the prepended mean token.
        return pooled[0]
+ - Performs anti-aliasing strided convolutions, where an avgpool is prepended to convolutions with stride > 1 + - The final pooling layer is a QKV attention instead of an average pool + """ + + def __init__(self, layers, output_dim, heads, image_size=224, width=64): + super().__init__() + self.output_dim = output_dim + self.image_size = image_size + + # the 3-layer stem + self.conv1 = nn.Conv2d( + 3, width // 2, kernel_size=3, stride=2, padding=1, bias=False + ) + self.bn1 = nn.BatchNorm2d(width // 2) + self.conv2 = nn.Conv2d( + width // 2, width // 2, kernel_size=3, padding=1, bias=False + ) + self.bn2 = nn.BatchNorm2d(width // 2) + self.conv3 = nn.Conv2d(width // 2, width, kernel_size=3, padding=1, bias=False) + self.bn3 = nn.BatchNorm2d(width) + self.avgpool = nn.AvgPool2d(2) + self.relu = nn.ReLU(inplace=True) + + # residual layers + self._inplanes = width # this is a *mutable* variable used during construction + self.layer1 = self._make_layer(width, layers[0]) + self.layer2 = self._make_layer(width * 2, layers[1], stride=2) + self.layer3 = self._make_layer(width * 4, layers[2], stride=2) + self.layer4 = self._make_layer(width * 8, layers[3], stride=2) + + embed_dim = width * 32 # the ResNet feature dimension + self.attnpool = AttentionPool2d(image_size // 32, embed_dim, heads, output_dim) + + self.init_parameters() + + def _make_layer(self, planes, blocks, stride=1): + layers = [Bottleneck(self._inplanes, planes, stride)] + + self._inplanes = planes * Bottleneck.expansion + for _ in range(1, blocks): + layers.append(Bottleneck(self._inplanes, planes)) + + return nn.Sequential(*layers) + + def init_parameters(self): + if self.attnpool is not None: + std = self.attnpool.c_proj.in_features**-0.5 + nn.init.normal_(self.attnpool.q_proj.weight, std=std) + nn.init.normal_(self.attnpool.k_proj.weight, std=std) + nn.init.normal_(self.attnpool.v_proj.weight, std=std) + nn.init.normal_(self.attnpool.c_proj.weight, std=std) + + for resnet_block in [self.layer1, 
class LayerNorm(nn.LayerNorm):
    """LayerNorm variant that returns its output in the dtype of the input."""

    def forward(self, x: torch.Tensor):
        input_dtype = x.dtype
        normalized = F.layer_norm(
            x, self.normalized_shape, self.weight, self.bias, self.eps
        )
        # Cast back in case the normalisation produced a different dtype.
        return normalized.to(input_dtype)
class Transformer(nn.Module):
    """Sequential stack of ``layers`` residual attention blocks.

    ``width`` and ``heads`` are passed to every ``ResidualAttentionBlock``
    together with ``act_layer``; ``width`` and ``layers`` are also stored as
    attributes.
    """

    def __init__(
        self, width: int, layers: int, heads: int, act_layer: Callable = nn.GELU
    ):
        super().__init__()
        self.width = width
        self.layers = layers
        blocks = []
        for _ in range(layers):
            blocks.append(ResidualAttentionBlock(width, heads, act_layer=act_layer))
        self.resblocks = nn.ModuleList(blocks)

    def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        """Feed ``x`` through every block in order, passing ``attn_mask`` to each."""
        hidden = x
        for block in self.resblocks:
            hidden = block(hidden, attn_mask=attn_mask)
        return hidden
@dataclass
class CLAPVisionCfg:
    """Configuration values for a vision tower (layer layout and timm options)."""

    layers: Union[Tuple[int, int, int, int], int] = 12
    width: int = 768
    patch_size: int = 16
    image_size: Union[Tuple[int, int], int] = 224
    # A valid model name overrides layers, width, patch_size.
    # Annotated Optional because the default is None.
    timm_model_name: Optional[str] = None
    # Use (imagenet) pretrained weights for the named model.
    timm_model_pretrained: bool = False
    # Feature pooling for the timm model ('abs_attn', 'rot_attn', 'avg', '').
    timm_pool: str = "avg"
    # Linear projection for the timm model output ('linear', 'mlp', '').
    timm_proj: str = "linear"
OpenAI models are pretrained w/ QuickGELU but native nn.GELU is both faster and more + # memory efficient in recent PyTorch releases (>= 1.10). + # NOTE: timm models always use native GELU regardless of quick_gelu flag. + act_layer = QuickGELU if quick_gelu else nn.GELU + + if mlp_act == "relu": + mlp_act_layer = nn.ReLU() + elif mlp_act == "gelu": + mlp_act_layer = nn.GELU() + else: + raise NotImplementedError + + # audio branch + # audio branch parameters + if audio_cfg.model_type == "PANN": + self.audio_branch = create_pann_model(audio_cfg, enable_fusion, fusion_type) + elif audio_cfg.model_type == "HTSAT": + self.audio_branch = create_htsat_model( + audio_cfg, enable_fusion, fusion_type + ) + else: + logging.error(f"Model config for {audio_cfg.model_type} not found") + raise RuntimeError(f"Model config for {audio_cfg.model_type} not found.") + + # text branch + # text branch parameters + if text_cfg.model_type == "transformer": + self.text_branch = Transformer( + width=text_cfg.width, + layers=text_cfg.layers, + heads=text_cfg.heads, + act_layer=act_layer, + ) + self.vocab_size = text_cfg.vocab_size + self.token_embedding = nn.Embedding(text_cfg.vocab_size, text_cfg.width) + self.positional_embedding = nn.Parameter( + torch.empty(self.context_length, text_cfg.width) + ) + self.ln_final = LayerNorm(text_cfg.width) + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(text_cfg.width, self.joint_embed_shape), + mlp_act_layer, + nn.Linear(self.joint_embed_shape, self.joint_embed_shape), + ) + elif text_cfg.model_type == "bert": + self.text_branch = BertModel.from_pretrained("bert-base-uncased") + self.text_transform = MLPLayers( + units=[ + self.joint_embed_shape, + self.joint_embed_shape, + self.joint_embed_shape, + ], + dropout=0.1, + ) + self.text_projection = nn.Sequential( + nn.Linear(768, self.joint_embed_shape), + 
def init_text_branch_parameters(self):
    """Initialise the text-branch weights and reset both logit scales.

    For the "transformer" text branch the token and positional embeddings and
    the attention / MLP weights of every residual block are drawn from normal
    distributions whose standard deviations depend on the branch width and
    depth.  For every branch type, ``logit_scale_a`` and ``logit_scale_t`` are
    set to ``log(1 / 0.07)``.
    """
    if self.text_branch_type == "transformer":
        width = self.text_branch.width
        nn.init.normal_(self.token_embedding.weight, std=0.02)
        nn.init.normal_(self.positional_embedding, std=0.01)
        proj_std = (width**-0.5) * ((2 * self.text_branch.layers) ** -0.5)
        attn_std = width**-0.5
        fc_std = (2 * width) ** -0.5
        for block in self.text_branch.resblocks:
            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)
            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)
            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)
            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)
    # The per-branch embedding width was only needed by a removed
    # text-projection init, so it is no longer looked up here.
    nn.init.constant_(self.logit_scale_a, np.log(1 / 0.07))
    nn.init.constant_(self.logit_scale_t, np.log(1 / 0.07))
self.text_branch_type == "transformer": + text = text.to(device=device, non_blocking=True) + x = self.token_embedding(text) # [batch_size, n_ctx, d_model] + + x = x + self.positional_embedding + x = x.permute(1, 0, 2) # NLD -> LND + x = self.text_branch(x, attn_mask=self.attn_mask) + x = x.permute(1, 0, 2) # LND -> NLD + x = self.ln_final(x) + + # x.shape = [batch_size, n_ctx, transformer.width] + # take features from the eot embedding (eot_token is the highest number in each sequence) + x = self.text_projection(x[torch.arange(x.shape[0]), text.argmax(dim=-1)]) + elif self.text_branch_type == "bert": + # text = self.list_of_dict_of_tensor2dict_of_tensor(text, device) + # text = BatchEncoding(text) + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + token_type_ids=text["token_type_ids"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "roberta": + x = self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["pooler_output"] + x = self.text_projection(x) + elif self.text_branch_type == "bart": + x = torch.mean( + self.text_branch( + input_ids=text["input_ids"].to(device=device, non_blocking=True), + attention_mask=text["attention_mask"].to( + device=device, non_blocking=True + ), + )["encoder_last_hidden_state"], + axis=1, + ) + x = self.text_projection(x) + else: + logging.error(f"Model type {self.text_branch_type} not found") + raise RuntimeError(f"Model type {self.text_branch_type} not found.") + return x + + def forward(self, audio, text, device=None): + """Forward audio and text into the CLAP + + Parameters + ---------- + audio: torch.Tensor (batch_size, audio_length) + the time-domain audio input / the batch of mel_spec and longer list. 
+ text: torch.Tensor () // need to add + the text token input + """ + if device is None: + if audio is not None: + device = audio.device + elif text is not None: + device = text.device + if audio is None and text is None: + # a hack to get the logit scale + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + elif audio is None: + return self.encode_text(text, device=device) + elif text is None: + return self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = self.audio_projection( + self.encode_audio(audio, device=device)["embedding"] + ) + audio_features = F.normalize(audio_features, dim=-1) + + text_features = self.encode_text(text, device=device) + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + # print("text_features.type", type(text_features)) + text_features = F.normalize(text_features, dim=-1) + + audio_features_mlp = self.audio_transform(audio_features) + text_features_mlp = self.text_transform(text_features) + # Four outputs: audio features (basic & MLP), text features (basic & MLP) + return ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + self.logit_scale_a.exp(), + self.logit_scale_t.exp(), + ) + + def get_logit_scale(self): + return self.logit_scale_a.exp(), self.logit_scale_t.exp() + + def get_text_embedding(self, data): + """Get the text embedding from the model + + Parameters + ---------- + data: torch.Tensor + a tensor of text embedding + + Returns + ---------- + text_embed: torch.Tensor + a tensor of text_embeds (N, D) + + """ + device = next(self.parameters()).device + for k in data: + data[k] = data[k].to(device) + if len(data[k].size()) < 2: + data[k] = data[k].unsqueeze(0) + text_embeds = self.encode_text(data, device=device) + text_embeds = F.normalize(text_embeds, dim=-1) + + return text_embeds + + def get_audio_embedding(self, data): + """Get the audio embedding from the model + + Parameters + ---------- + 
def audio_infer(self, audio, hopsize=None, device=None, key="embedding"):
    """Forward one audio clip and return the requested audio-branch output.

    Parameters
    ----------
    audio: (audio_length)
        the time-domain audio input, notice that it must be only one input
    hopsize: int
        stride, in samples, between consecutive sliding windows when the
        clip is longer than ``audio_cfg.clip_samples`` (HTSAT only).
        Defaults to ``audio_cfg.clip_samples`` (non-overlapping windows).
    device: torch.device
        forwarded unchanged to ``self.encode_audio``
    key: str
        which entry of the ``encode_audio`` output dict to return.
        Defaults to ``"embedding"``.

    Returns
    ----------
    output_dict: {
        key: [n, (embedding_shape)] for an "HTSAT" clip longer than
             ``clip_samples`` (one row per window)
        or
        key: [(embedding_shape)] otherwise
    }
    The dict is empty when ``audio_cfg.model_type`` is neither "PANN"
    nor "HTSAT".

    Raises
    ----------
    ValueError
        if the audio is empty or ``hopsize`` is not positive (HTSAT only)

    """

    assert not self.training, "the inference mode must be run at eval stage"
    output_dict = {}
    model_type = self.audio_cfg.model_type

    if model_type == "PANN":
        audio_input = audio.unsqueeze(dim=0)
        output_dict[key] = self.encode_audio(audio_input, device=device)[
            key
        ].squeeze(dim=0)
    elif model_type == "HTSAT":
        clip_samples = self.audio_cfg.clip_samples
        audio_len = len(audio)
        if audio_len == 0:
            raise ValueError("audio_infer requires a non-empty audio tensor")

        # Tile short clips so they fill as much of one window as possible.
        repeats = clip_samples // audio_len
        if repeats > 1:
            audio = audio.repeat(repeats)
            audio_len = len(audio)

        # The previous code evaluated min(None, audio_len) here, which
        # raises TypeError; fall back to one full window per hop instead.
        if hopsize is None:
            hopsize = clip_samples
        if hopsize <= 0:
            raise ValueError(f"hopsize must be positive, got {hopsize}")

        if audio_len > clip_samples:
            windows = [
                audio[pos : pos + clip_samples].clone()
                for pos in range(0, audio_len - clip_samples, hopsize)
            ]
            # Always include the tail so the end of the clip is covered.
            windows.append(audio[-clip_samples:].clone())
            audio_input = torch.stack(windows)
            output_dict[key] = self.encode_audio(audio_input, device=device)[key]
        else:
            audio_input = audio.unsqueeze(dim=0)
            output_dict[key] = self.encode_audio(audio_input, device=device)[
                key
            ].squeeze(dim=0)

    return output_dict
def trace_model(model, batch_size=256, device=torch.device("cpu")):
    """Trace ``model`` with ``torch.jit.trace_module`` using dummy inputs.

    The model is put in eval mode, then traced with an all-ones audio batch
    of shape ``(batch_size, model.audio_cfg.audio_length)`` and an all-zeros
    integer text batch of shape ``(batch_size, model.context_length)``.

    Returns the traced module.
    """
    model.eval()
    n_audio = model.audio_cfg.audio_length

    dummy_audio = torch.ones((batch_size, n_audio), device=device)
    dummy_text = torch.zeros(
        (batch_size, model.context_length), dtype=torch.int, device=device
    )

    # NOTE(review): `encode_image` is traced with the audio example; confirm
    # the model actually exposes a method of that name.
    example_inputs = {
        "forward": (dummy_audio, dummy_text),
        "encode_text": (dummy_text,),
        "encode_image": (dummy_audio,),
    }
    traced = torch.jit.trace_module(model, inputs=example_inputs)

    # Write the pre-trace audio length back onto the traced module's config.
    traced.audio_cfg.audio_length = n_audio
    return traced
"audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "large" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json new file mode 100644 index 0000000000000000000000000000000000000000..73e42990fe8361a0df502e7f93d29f19f58c9ecb --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-tiny-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json new file mode 100644 index 0000000000000000000000000000000000000000..a6e7821163d9afa81c27345a1e472475b92af169 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/HTSAT-tiny.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 768, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "HTSAT", + "model_name": "tiny" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-10.json 
b/audioldm/clap/open_clip/model_configs/PANN-10.json new file mode 100644 index 0000000000000000000000000000000000000000..954ddf62921aed7dde9c37ffffec98a2e96a4ee7 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-10.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 1024, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn10" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json new file mode 100644 index 0000000000000000000000000000000000000000..b7989bc0cd95d0d39049b7524eba508b3e386439 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-18k.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 18000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json new file mode 100644 index 0000000000000000000000000000000000000000..56bdb56bedc304ffa52d8bf5988cea2c1d82d14e --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-fmax-8k-20s.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 960000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 360, + "fmin": 50, + "fmax": 8000, + 
"class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json new file mode 100644 index 0000000000000000000000000000000000000000..5756e3bebc97cc985f512cb081930fee4e49bec1 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-tiny-transformer.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 4 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json new file mode 100644 index 0000000000000000000000000000000000000000..5a9e7e208b661619d5e26625e849da1adda8a475 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14-win-1536.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1536, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-14.json b/audioldm/clap/open_clip/model_configs/PANN-14.json new file mode 100644 index 0000000000000000000000000000000000000000..39a5134cde1d8c50f4758377c952ef22f07bab41 --- 
/dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-14.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 2048, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn14" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/PANN-6.json b/audioldm/clap/open_clip/model_configs/PANN-6.json new file mode 100644 index 0000000000000000000000000000000000000000..21ebc344326de260c386ba77e0ad63cf9b04febf --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/PANN-6.json @@ -0,0 +1,23 @@ +{ + "embed_dim": 512, + "audio_cfg": { + "audio_length": 1024, + "clip_samples": 480000, + "mel_bins": 64, + "sample_rate": 48000, + "window_size": 1024, + "hop_size": 480, + "fmin": 50, + "fmax": 14000, + "class_num": 527, + "model_type": "PANN", + "model_name": "Cnn6" + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..d0db2c161d13138788c4609d373b023b8454d624 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN101-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN101.json b/audioldm/clap/open_clip/model_configs/RN101.json 
new file mode 100644 index 0000000000000000000000000000000000000000..b88b4d3acbaa701c614ab0ea65fc88fcfe289c32 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN101.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 23, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..8c2f91260cdeb043434dc1e893cce81d4ce7f0d1 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50-quickgelu.json @@ -0,0 +1,22 @@ +{ + "embed_dim": 1024, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} diff --git a/audioldm/clap/open_clip/model_configs/RN50.json b/audioldm/clap/open_clip/model_configs/RN50.json new file mode 100644 index 0000000000000000000000000000000000000000..33aa884d54fee0076c33676831e49d5e1ffcb8f2 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 1024, + "vision_cfg": { + "image_size": 224, + "layers": [ + 3, + 4, + 6, + 3 + ], + "width": 64, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50x16.json b/audioldm/clap/open_clip/model_configs/RN50x16.json new file mode 100644 index 0000000000000000000000000000000000000000..3161e1a2c9a839161e652a4d729c2cdc971161db --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50x16.json @@ -0,0 +1,21 
@@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 384, + "layers": [ + 6, + 8, + 18, + 8 + ], + "width": 96, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/RN50x4.json b/audioldm/clap/open_clip/model_configs/RN50x4.json new file mode 100644 index 0000000000000000000000000000000000000000..e155237f8ce1026aaaeecc80751eabe6f329f0bb --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/RN50x4.json @@ -0,0 +1,21 @@ +{ + "embed_dim": 640, + "vision_cfg": { + "image_size": 288, + "layers": [ + 4, + 6, + 10, + 6 + ], + "width": 80, + "patch_size": null + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 640, + "heads": 10, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-16.json b/audioldm/clap/open_clip/model_configs/ViT-B-16.json new file mode 100644 index 0000000000000000000000000000000000000000..395eea77ec3907c0611531aba63459b193e67b9c --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-16.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 16 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json new file mode 100644 index 0000000000000000000000000000000000000000..ce6bd923593293ed50dfcfb28b73ca7403bcf3c5 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-32-quickgelu.json @@ -0,0 +1,17 @@ +{ + "embed_dim": 512, + "quick_gelu": true, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + 
"vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-B-32.json b/audioldm/clap/open_clip/model_configs/ViT-B-32.json new file mode 100644 index 0000000000000000000000000000000000000000..07c8e28eb06fa1813ba932fe4eec668262d1c47f --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-B-32.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 512, + "vision_cfg": { + "image_size": 224, + "layers": 12, + "width": 768, + "patch_size": 32 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 512, + "heads": 8, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/model_configs/ViT-L-14.json b/audioldm/clap/open_clip/model_configs/ViT-L-14.json new file mode 100644 index 0000000000000000000000000000000000000000..d4a4bbb1dd4ed4edb317d3ace4f3ad13b211c241 --- /dev/null +++ b/audioldm/clap/open_clip/model_configs/ViT-L-14.json @@ -0,0 +1,16 @@ +{ + "embed_dim": 768, + "vision_cfg": { + "image_size": 224, + "layers": 24, + "width": 1024, + "patch_size": 14 + }, + "text_cfg": { + "context_length": 77, + "vocab_size": 49408, + "width": 768, + "heads": 12, + "layers": 12 + } +} \ No newline at end of file diff --git a/audioldm/clap/open_clip/openai.py b/audioldm/clap/open_clip/openai.py new file mode 100644 index 0000000000000000000000000000000000000000..fcb624f54a8b9d2c4b11e3adb50c53c3261716d4 --- /dev/null +++ b/audioldm/clap/open_clip/openai.py @@ -0,0 +1,159 @@ +""" OpenAI pretrained model functions + +Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI. 
def load_openai_model(
    name: str,
    model_cfg,
    device: Union[str, torch.device] = "cuda" if torch.cuda.is_available() else "cpu",
    jit=True,
    cache_dir=os.path.expanduser(f"{CACHE_DIR}/clip"),
    enable_fusion: bool = False,
    fusion_type: str = "None",
):
    """Load a CLIP model, preserve its text pretrained part, and set in the CLAP model

    Parameters
    ----------
    name : str
        A pretrained "openai" model name known to `get_pretrained_url`, or the
        path to a model checkpoint on disk
    model_cfg
        Model configuration forwarded to `build_model_from_openai_state_dict`
    device : Union[str, torch.device]
        The device to put the loaded model
    jit : bool
        Whether to load the optimized JIT model (default) or more hackable non-JIT model.
    cache_dir : str
        Directory that downloaded checkpoints are stored in
    enable_fusion : bool
        Forwarded to `build_model_from_openai_state_dict`
    fusion_type : str
        Forwarded to `build_model_from_openai_state_dict`

    Returns
    -------
    model : torch.nn.Module
        The CLAP model (only the model is returned; there is no preprocess
        transform)

    Raises
    ------
    RuntimeError
        If `name` is neither a known pretrained model nor an existing file
    KeyError
        If the checkpoint lacks the expected keys and has no nested
        "state_dict" entry to fall back to
    """
    pretrained_url = get_pretrained_url(name, "openai")
    if pretrained_url:
        model_path = download_pretrained(pretrained_url, root=cache_dir)
    elif os.path.isfile(name):
        model_path = name
    else:
        raise RuntimeError(
            f"Model {name} not found; available models = {list_openai_models()}"
        )

    try:
        # loading JIT archive
        model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
        state_dict = None
    except RuntimeError:
        # loading saved state dict
        if jit:
            warnings.warn(
                f"File {model_path} is not a JIT archive. Loading as a state dict instead"
            )
            jit = False
        # SECURITY: torch.load unpickles the file; only use trusted checkpoints.
        state_dict = torch.load(model_path, map_location="cpu")

    if not jit:
        source_sd = state_dict or model.state_dict()
        try:
            model = build_model_from_openai_state_dict(
                source_sd, model_cfg, enable_fusion, fusion_type
            ).to(device)
        except KeyError:
            # Fall back to checkpoints that nest the weights under "state_dict".
            # Previously this indexed `state_dict` even when it was None, which
            # replaced the real KeyError with a TypeError.
            nested = source_sd.get("state_dict") if isinstance(source_sd, dict) else None
            if nested is None:
                raise
            # Drops the first 7 characters of every key — presumably a
            # "module." prefix; confirm against the checkpoints in use.
            sd = {k[7:]: v for k, v in nested.items()}
            model = build_model_from_openai_state_dict(
                sd, model_cfg, enable_fusion, fusion_type
            ).to(device)

        if str(device) == "cpu":
            model.float()
        return model

    # patch the device names
    device_holder = torch.jit.trace(
        lambda: torch.ones([]).to(torch.device(device)), example_inputs=[]
    )
    device_node = [
        n
        for n in device_holder.graph.findAllNodes("prim::Constant")
        if "Device" in repr(n)
    ][-1]

    def patch_device(module):
        """Rewrite every "cuda*" constant in the module's graphs to `device`."""
        try:
            graphs = [module.graph] if hasattr(module, "graph") else []
        except RuntimeError:
            graphs = []

        if hasattr(module, "forward1"):
            graphs.append(module.forward1.graph)

        for graph in graphs:
            for node in graph.findAllNodes("prim::Constant"):
                if "value" in node.attributeNames() and str(node["value"]).startswith(
                    "cuda"
                ):
                    node.copyAttributes(device_node)

    model.apply(patch_device)
    patch_device(model.encode_audio)
    patch_device(model.encode_text)

    # patch dtype to float32 on CPU
    if str(device) == "cpu":
        float_holder = torch.jit.trace(
            lambda: torch.ones([]).float(), example_inputs=[]
        )
        float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
        float_node = float_input.node()

        def patch_float(module):
            """Rewrite dtype constant 5 in `aten::to` calls to the float32 node."""
            try:
                graphs = [module.graph] if hasattr(module, "graph") else []
            except RuntimeError:
                graphs = []

            if hasattr(module, "forward1"):
                graphs.append(module.forward1.graph)

            for graph in graphs:
                for node in graph.findAllNodes("aten::to"):
                    inputs = list(node.inputs())
                    for i in [
                        1,
                        2,
                    ]:  # dtype can be the second or third argument to aten::to()
                        if inputs[i].node()["value"] == 5:
                            inputs[i].node().copyAttributes(float_node)

        model.apply(patch_float)
        patch_float(model.encode_audio)
        patch_float(model.encode_text)
        model.float()

    model.audio_branch.audio_length = model.audio_cfg.audio_length
    return model
class ConvBlock5x5(nn.Module):
    """Single 5x5 convolution + batch norm + ReLU, followed by 2-D pooling."""

    # Pooling modes accepted by ``forward``.
    _POOL_TYPES = ("max", "avg", "avg+max")

    def __init__(self, in_channels, out_channels):

        super(ConvBlock5x5, self).__init__()

        # padding=2 with a 5x5 kernel and stride 1 keeps the spatial size.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=(5, 5),
            stride=(1, 1),
            padding=(2, 2),
            bias=False,
        )

        self.bn1 = nn.BatchNorm2d(out_channels)

        self.init_weight()

    def init_weight(self):
        """Initialise the convolution and batch-norm parameters."""
        init_layer(self.conv1)
        init_bn(self.bn1)

    def forward(self, input, pool_size=(2, 2), pool_type="avg"):
        """Apply conv -> bn -> relu, then pool with ``pool_type``.

        ``pool_type`` is one of "max", "avg" or "avg+max" (sum of both
        poolings). Raises ``ValueError`` for any other value, before any
        computation is done.
        """
        if pool_type not in self._POOL_TYPES:
            raise ValueError(
                f"Incorrect argument! pool_type must be one of "
                f"{self._POOL_TYPES}, got {pool_type!r}"
            )

        x = F.relu_(self.bn1(self.conv1(input)))
        if pool_type == "max":
            return F.max_pool2d(x, kernel_size=pool_size)
        if pool_type == "avg":
            return F.avg_pool2d(x, kernel_size=pool_size)
        # "avg+max"
        return F.avg_pool2d(x, kernel_size=pool_size) + F.max_pool2d(
            x, kernel_size=pool_size
        )
class Cnn14(nn.Module):
    """Six-ConvBlock PANN audio encoder with optional feature fusion.

    ``forward`` takes a dict. Without fusion it reads ``input["waveform"]``;
    with fusion it reads ``input["mel_fusion"]`` and the boolean mask
    ``input["longer"]``.
    """

    def __init__(
        self,
        sample_rate,
        window_size,
        hop_size,
        mel_bins,
        fmin,
        fmax,
        classes_num,
        enable_fusion=False,
        fusion_type="None",
    ):

        super(Cnn14, self).__init__()

        window = "hann"
        center = True
        pad_mode = "reflect"
        ref = 1.0
        amin = 1e-10
        top_db = None

        self.enable_fusion = enable_fusion
        self.fusion_type = fusion_type

        # Spectrogram extractor
        self.spectrogram_extractor = Spectrogram(
            n_fft=window_size,
            hop_length=hop_size,
            win_length=window_size,
            window=window,
            center=center,
            pad_mode=pad_mode,
            freeze_parameters=True,
        )

        # Logmel feature extractor
        self.logmel_extractor = LogmelFilterBank(
            sr=sample_rate,
            n_fft=window_size,
            n_mels=mel_bins,
            fmin=fmin,
            fmax=fmax,
            ref=ref,
            amin=amin,
            top_db=top_db,
            freeze_parameters=True,
        )

        # Spec augmenter
        self.spec_augmenter = SpecAugmentation(
            time_drop_width=64,
            time_stripes_num=2,
            freq_drop_width=8,
            freq_stripes_num=2,
        )

        self.bn0 = nn.BatchNorm2d(64)

        # "channel_map" fusion feeds all 4 fusion channels to the first block.
        if (self.enable_fusion) and (self.fusion_type == "channel_map"):
            self.conv_block1 = ConvBlock(in_channels=4, out_channels=64)
        else:
            self.conv_block1 = ConvBlock(in_channels=1, out_channels=64)
        self.conv_block2 = ConvBlock(in_channels=64, out_channels=128)
        self.conv_block3 = ConvBlock(in_channels=128, out_channels=256)
        self.conv_block4 = ConvBlock(in_channels=256, out_channels=512)
        self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024)
        self.conv_block6 = ConvBlock(in_channels=1024, out_channels=2048)

        self.fc1 = nn.Linear(2048, 2048, bias=True)
        self.fc_audioset = nn.Linear(2048, classes_num, bias=True)

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]
        ):
            self.mel_conv1d = nn.Sequential(
                nn.Conv1d(64, 64, kernel_size=5, stride=3, padding=2),
                nn.BatchNorm1d(64),  # No Relu
            )
            if self.fusion_type == "daf_1d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_1d":
                self.fusion_model = AFF(channels=64, type="1D")
            elif self.fusion_type == "iaff_1d":
                self.fusion_model = iAFF(channels=64, type="1D")

        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            self.mel_conv2d = nn.Sequential(
                nn.Conv2d(1, 64, kernel_size=(5, 5), stride=(6, 2), padding=(2, 2)),
                nn.BatchNorm2d(64),
                nn.ReLU(inplace=True),
            )

            if self.fusion_type == "daf_2d":
                self.fusion_model = DAF()
            elif self.fusion_type == "aff_2d":
                self.fusion_model = AFF(channels=64, type="2D")
            elif self.fusion_type == "iaff_2d":
                self.fusion_model = iAFF(channels=64, type="2D")
        self.init_weight()

    def init_weight(self):
        """Initialise the input batch norm and the two linear layers."""
        init_bn(self.bn0)
        init_layer(self.fc1)
        init_layer(self.fc_audioset)

    def forward(self, input, mixup_lambda=None, device=None):
        """Encode a batch and return clip-level and fine-grained outputs.

        Parameters
        ----------
        input: dict
            "waveform" when fusion is disabled; "mel_fusion" and "longer"
            when it is enabled. NOTE: with fusion enabled and no entry of
            ``input["longer"]`` set, one random entry is set to True in place.
        mixup_lambda:
            mixup coefficients, applied only in training mode
        device:
            device the inputs are moved to (``None`` leaves them where they are)

        Returns
        ----------
        dict with keys "clipwise_output", "embedding" and
        "fine_grained_embedding"
        """

        if self.enable_fusion and input["longer"].sum() == 0:
            # if no audio is longer than 10s, then randomly select one audio to be longer
            input["longer"][torch.randint(0, input["longer"].shape[0], (1,))] = True

        if not self.enable_fusion:
            x = self.spectrogram_extractor(
                input["waveform"].to(device=device, non_blocking=True)
            )  # (batch_size, 1, time_steps, freq_bins)
            x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
        else:
            longer_list = input["longer"].to(device=device, non_blocking=True)
            x = input["mel_fusion"].to(device=device, non_blocking=True)
            longer_list_idx = torch.where(longer_list)[0]
            x = x.transpose(1, 3)
            x = self.bn0(x)
            x = x.transpose(1, 3)
            if self.fusion_type in ["daf_1d", "aff_1d", "iaff_1d"]:
                new_x = x[:, 0:1, :, :].clone().contiguous()
                # local processing
                if len(longer_list_idx) > 0:
                    fusion_x_local = x[longer_list_idx, 1:, :, :].clone().contiguous()
                    FB, FC, FT, FF = fusion_x_local.size()
                    fusion_x_local = fusion_x_local.view(FB * FC, FT, FF)
                    fusion_x_local = torch.permute(
                        fusion_x_local, (0, 2, 1)
                    ).contiguous()
                    fusion_x_local = self.mel_conv1d(fusion_x_local)
                    fusion_x_local = fusion_x_local.view(
                        FB, FC, FF, fusion_x_local.size(-1)
                    )
                    fusion_x_local = (
                        torch.permute(fusion_x_local, (0, 2, 1, 3))
                        .contiguous()
                        .flatten(2)
                    )
                    if fusion_x_local.size(-1) < FT:
                        # new_zeros keeps the pad on the same device/dtype as the
                        # features; the `device` argument may be None here.
                        fusion_x_local = torch.cat(
                            [
                                fusion_x_local,
                                fusion_x_local.new_zeros(
                                    (FB, FF, FT - fusion_x_local.size(-1))
                                ),
                            ],
                            dim=-1,
                        )
                    else:
                        fusion_x_local = fusion_x_local[:, :, :FT]
                    # 1D fusion
                    new_x = new_x.squeeze(1).permute((0, 2, 1)).contiguous()
                    new_x[longer_list_idx] = self.fusion_model(
                        new_x[longer_list_idx], fusion_x_local
                    )
                    x = new_x.permute((0, 2, 1)).contiguous()[:, None, :, :]
                else:
                    x = new_x
            # For "daf_2d", "aff_2d", "iaff_2d" and "channel_map" x is left
            # unchanged at this stage.

        if self.training:
            x = self.spec_augmenter(x)
        # Mixup on spectrogram
        if self.training and mixup_lambda is not None:
            x = do_mixup(x, mixup_lambda)
        if (self.enable_fusion) and (
            self.fusion_type in ["daf_2d", "aff_2d", "iaff_2d"]
        ):
            global_x = x[:, 0:1, :, :]

            # global processing
            global_x = self.conv_block1(global_x, pool_size=(2, 2), pool_type="avg")
            if len(longer_list_idx) > 0:
                local_x = x[longer_list_idx, 1:, :, :].contiguous()
                TH = global_x.size(-2)
                # local processing
                B, C, H, W = local_x.shape
                local_x = local_x.view(B * C, 1, H, W)
                local_x = self.mel_conv2d(local_x)
                local_x = local_x.view(
                    B, C, local_x.size(1), local_x.size(2), local_x.size(3)
                )
                local_x = local_x.permute((0, 2, 1, 3, 4)).contiguous().flatten(2, 3)
                TB, TC, _, TW = local_x.size()
                if local_x.size(-2) < TH:
                    local_x = torch.cat(
                        [
                            local_x,
                            local_x.new_zeros((TB, TC, TH - local_x.size(-2), TW)),
                        ],
                        dim=-2,
                    )
                else:
                    local_x = local_x[:, :, :TH, :]

                global_x[longer_list_idx] = self.fusion_model(
                    global_x[longer_list_idx], local_x
                )
            x = global_x
        else:
            x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg")

        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = self.conv_block6(x, pool_size=(1, 1), pool_type="avg")
        x = F.dropout(x, p=0.2, training=self.training)
        x = torch.mean(x, dim=3)

        # Per-position branch: pooled over a 3-wide neighbourhood along dim 2.
        latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1)
        latent_x = latent_x1 + latent_x2
        latent_x = latent_x.transpose(1, 2)
        latent_x = F.relu_(self.fc1(latent_x))
        latent_output = interpolate(latent_x, 32)

        # Clip-level branch: max + mean over dim 2.
        (x1, _) = torch.max(x, dim=2)
        x2 = torch.mean(x, dim=2)
        x = x1 + x2
        x = F.dropout(x, p=0.5, training=self.training)
        x = F.relu_(self.fc1(x))
        embedding = F.dropout(x, p=0.5, training=self.training)
        clipwise_output = torch.sigmoid(self.fc_audioset(x))

        output_dict = {
            "clipwise_output": clipwise_output,
            "embedding": embedding,
            "fine_grained_embedding": latent_output,
        }
        return output_dict
+ fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn6, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock5x5(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock5x5(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock5x5(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock5x5(in_channels=256, out_channels=512) + + self.fc1 = nn.Linear(512, 512, bias=True) + self.fc_audioset = nn.Linear(512, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = 
F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = latent_x.transpose(1, 2) + latent_x = F.relu_(self.fc1(latent_x)) + latent_output = interpolate(latent_x, 16) + + (x1, _) = torch.max(x, dim=2) + x2 = torch.mean(x, dim=2) + x = x1 + x2 + x = F.dropout(x, p=0.5, training=self.training) + x = F.relu_(self.fc1(x)) + embedding = F.dropout(x, p=0.5, training=self.training) + clipwise_output = torch.sigmoid(self.fc_audioset(x)) + + output_dict = { + "clipwise_output": clipwise_output, + "embedding": embedding, + "fine_grained_embedding": latent_output, + } + + return output_dict + + +class Cnn10(nn.Module): + def __init__( + self, + sample_rate, + window_size, + hop_size, + mel_bins, + fmin, + fmax, + classes_num, + enable_fusion=False, + fusion_type="None", + ): + + super(Cnn10, self).__init__() + + window = "hann" + center = True + pad_mode = "reflect" + ref = 1.0 + amin = 1e-10 + top_db = None + + self.enable_fusion = enable_fusion + self.fusion_type = fusion_type + + # Spectrogram extractor + self.spectrogram_extractor = Spectrogram( + n_fft=window_size, + hop_length=hop_size, + win_length=window_size, + window=window, + center=center, + pad_mode=pad_mode, + freeze_parameters=True, + ) + + # Logmel feature extractor + self.logmel_extractor = LogmelFilterBank( + sr=sample_rate, + n_fft=window_size, + n_mels=mel_bins, + fmin=fmin, + fmax=fmax, + ref=ref, + amin=amin, + top_db=top_db, + freeze_parameters=True, + ) + + # Spec augmenter + 
self.spec_augmenter = SpecAugmentation( + time_drop_width=64, + time_stripes_num=2, + freq_drop_width=8, + freq_stripes_num=2, + ) + + self.bn0 = nn.BatchNorm2d(64) + + self.conv_block1 = ConvBlock(in_channels=1, out_channels=64) + self.conv_block2 = ConvBlock(in_channels=64, out_channels=128) + self.conv_block3 = ConvBlock(in_channels=128, out_channels=256) + self.conv_block4 = ConvBlock(in_channels=256, out_channels=512) + self.conv_block5 = ConvBlock(in_channels=512, out_channels=1024) + + self.fc1 = nn.Linear(1024, 1024, bias=True) + self.fc_audioset = nn.Linear(1024, classes_num, bias=True) + + self.init_weight() + + def init_weight(self): + init_bn(self.bn0) + init_layer(self.fc1) + init_layer(self.fc_audioset) + + def forward(self, input, mixup_lambda=None, device=None): + """ + Input: (batch_size, data_length)""" + + x = self.spectrogram_extractor(input) # (batch_size, 1, time_steps, freq_bins) + x = self.logmel_extractor(x) # (batch_size, 1, time_steps, mel_bins) + + x = x.transpose(1, 3) + x = self.bn0(x) + x = x.transpose(1, 3) + + if self.training: + x = self.spec_augmenter(x) + + # Mixup on spectrogram + if self.training and mixup_lambda is not None: + x = do_mixup(x, mixup_lambda) + + x = self.conv_block1(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block2(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block3(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block4(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = self.conv_block5(x, pool_size=(2, 2), pool_type="avg") + x = F.dropout(x, p=0.2, training=self.training) + x = torch.mean(x, dim=3) + + latent_x1 = F.max_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x2 = F.avg_pool1d(x, kernel_size=3, stride=1, padding=1) + latent_x = latent_x1 + latent_x2 + latent_x = 
def create_pann_model(audio_cfg, enable_fusion=False, fusion_type="None"):
    """Instantiate the PANN model class named by ``audio_cfg.model_name``.

    The class is looked up by name among this module's globals rather than
    through ``eval`` so that a configuration string can never execute
    arbitrary code.

    Args:
        audio_cfg: Object exposing ``model_name``, ``sample_rate``,
            ``window_size``, ``hop_size``, ``mel_bins``, ``fmin``, ``fmax``
            and ``class_num`` attributes.
        enable_fusion: Forwarded to the model constructor.
        fusion_type: Forwarded to the model constructor.

    Returns:
        The constructed model instance.

    Raises:
        RuntimeError: If no class of that name exists in this module or the
            constructor fails (the original error is chained as the cause).
    """
    model_name = audio_cfg.model_name
    failure_message = (
        f"Import Model for {model_name} not found, or the audio cfg parameters are not enough."
    )

    ModelProto = globals().get(model_name) if isinstance(model_name, str) else None
    if not isinstance(ModelProto, type):
        raise RuntimeError(failure_message)

    try:
        return ModelProto(
            sample_rate=audio_cfg.sample_rate,
            window_size=audio_cfg.window_size,
            hop_size=audio_cfg.hop_size,
            mel_bins=audio_cfg.mel_bins,
            fmin=audio_cfg.fmin,
            fmax=audio_cfg.fmax,
            classes_num=audio_cfg.class_num,
            enable_fusion=enable_fusion,
            fusion_type=fusion_type,
        )
    except Exception as err:
        # Keep the historical exception type for callers, but preserve the
        # underlying cause instead of swallowing it.
        raise RuntimeError(failure_message) from err
openai="https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt", +) + +_RN50x64 = dict( + openai="https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt", +) + +_VITB32 = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB32_quickgelu = dict( + openai="https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt", + laion400m_e31="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e31-d867053b.pt", + laion400m_e32="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_e32-46683a32.pt", + laion400m_avg="https://github.com/mlfoundations/open_clip/releases/download/v0.2-weights/vit_b_32-quickgelu-laion400m_avg-8a00ab3c.pt", +) + +_VITB16 = dict( + openai="https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt", +) + +_VITL14 = dict( + openai="https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt", +) + +_PRETRAINED = { + "RN50": _RN50, + "RN50-quickgelu": _RN50_quickgelu, + "RN101": _RN101, + "RN101-quickgelu": _RN101_quickgelu, + "RN50x4": _RN50x4, + "RN50x16": _RN50x16, + "ViT-B-32": _VITB32, + "ViT-B-32-quickgelu": 
def _sha256_of_file(path, chunk_size=1 << 20):
    """Return the hex SHA256 digest of *path*, read in fixed-size chunks."""
    digest = hashlib.sha256()
    with open(path, "rb") as handle:
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


def download_pretrained(url: str, root: str = os.path.expanduser(f"{CACHE_DIR}/clip")):
    """Download a checkpoint into *root* unless a valid copy is cached.

    For URLs containing ``"openaipublic"`` the second-to-last path component
    is treated as the expected SHA256 of the file and is verified both for
    a cached copy and after downloading.

    Args:
        url: Location of the checkpoint.
        root: Directory in which the file is stored (created if missing).

    Returns:
        Path of the local checkpoint file.

    Raises:
        RuntimeError: If the target path exists but is not a regular file,
            or the downloaded file fails the checksum.
    """
    # ``import urllib`` alone does not guarantee the ``request`` submodule
    # is loaded, so import it explicitly where it is needed.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)

    expected_sha256 = url.split("/")[-2] if "openaipublic" in url else ""

    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        if not expected_sha256:
            return download_target
        if _sha256_of_file(download_target) == expected_sha256:
            return download_target
        warnings.warn(
            f"{download_target} exists, but the SHA256 checksum does not match; re-downloading the file"
        )

    with urllib.request.urlopen(url) as source, open(download_target, "wb") as output:
        # Servers may omit Content-Length; tqdm accepts ``None`` as "unknown".
        content_length = source.info().get("Content-Length")
        total = int(content_length) if content_length is not None else None
        with tqdm(
            total=total,
            ncols=80,
            unit="iB",
            unit_scale=True,
        ) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if expected_sha256 and _sha256_of_file(download_target) != expected_sha256:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match"
        )

    return download_target
else: + # reset global pool if pool config set, otherwise leave as network default + reset_kwargs = dict(global_pool=pool) if pool else {} + self.trunk.reset_classifier(0, **reset_kwargs) + prev_chs = self.trunk.num_features + + head_layers = OrderedDict() + if pool == "abs_attn": + head_layers["pool"] = AbsAttentionPool2d( + prev_chs, feat_size=feat_size, out_features=embed_dim + ) + prev_chs = embed_dim + elif pool == "rot_attn": + head_layers["pool"] = RotAttentionPool2d(prev_chs, out_features=embed_dim) + prev_chs = embed_dim + else: + assert proj, "projection layer needed if non-attention pooling is used." + + # NOTE attention pool ends with a projection layer, so proj should usually be set to '' if such pooling is used + if proj == "linear": + head_layers["drop"] = nn.Dropout(drop) + head_layers["proj"] = nn.Linear(prev_chs, embed_dim) + elif proj == "mlp": + head_layers["mlp"] = Mlp(prev_chs, 2 * embed_dim, embed_dim, drop=drop) + + self.head = nn.Sequential(head_layers) + + def lock(self, unlocked_groups=0, freeze_bn_stats=False): + """lock modules + Args: + unlocked_groups (int): leave last n layer groups unlocked (default: 0) + """ + if not unlocked_groups: + # lock full model + for param in self.trunk.parameters(): + param.requires_grad = False + if freeze_bn_stats: + freeze_batch_norm_2d(self.trunk) + else: + # NOTE: partial freeze requires latest timm (master) branch and is subject to change + try: + # FIXME import here until API stable and in an official release + from timm.models.helpers import group_parameters, group_modules + except ImportError: + raise RuntimeError( + "Please install latest timm `pip install git+https://github.com/rwightman/pytorch-image-models`" + ) + matcher = self.trunk.group_matcher() + gparams = group_parameters(self.trunk, matcher) + max_layer_id = max(gparams.keys()) + max_layer_id = max_layer_id - unlocked_groups + for group_idx in range(max_layer_id + 1): + group = gparams[group_idx] + for param in group: + 
@lru_cache()
def bytes_to_unicode():
    """Build the reversible byte -> unicode-character table used by the BPE.

    Bytes that already correspond to a printable, non-whitespace character
    map to that same character.  Every other byte is assigned, in ascending
    byte order, the next free code point starting at 256, so that no byte
    ever maps to a whitespace or control character.

    Returns:
        dict mapping each of the 256 byte values to a one-character string;
        the printable bytes come first in insertion order.
    """
    printable = [
        *range(ord("!"), ord("~") + 1),
        *range(ord("¡"), ord("¬") + 1),
        *range(ord("®"), ord("ÿ") + 1),
    ]
    table = {byte: chr(byte) for byte in printable}
    leftovers = (byte for byte in range(2**8) if byte not in table)
    for offset, byte in enumerate(list(leftovers)):
        table[byte] = chr(2**8 + offset)
    return table
def whitespace_clean(text):
    """Collapse each run of whitespace into one space and trim both ends."""
    collapsed = re.sub(r"\s+", " ", text)
    return collapsed.strip()
word.index(first, i) + new_word.extend(word[i:j]) + i = j + except: + new_word.extend(word[i:]) + break + + if word[i] == first and i < len(word) - 1 and word[i + 1] == second: + new_word.append(first + second) + i += 2 + else: + new_word.append(word[i]) + i += 1 + new_word = tuple(new_word) + word = new_word + if len(word) == 1: + break + else: + pairs = get_pairs(word) + word = " ".join(word) + self.cache[token] = word + return word + + def encode(self, text): + bpe_tokens = [] + text = whitespace_clean(basic_clean(text)).lower() + for token in re.findall(self.pat, text): + token = "".join(self.byte_encoder[b] for b in token.encode("utf-8")) + bpe_tokens.extend( + self.encoder[bpe_token] for bpe_token in self.bpe(token).split(" ") + ) + return bpe_tokens + + def decode(self, tokens): + text = "".join([self.decoder[token] for token in tokens]) + text = ( + bytearray([self.byte_decoder[c] for c in text]) + .decode("utf-8", errors="replace") + .replace("", " ") + ) + return text + + +_tokenizer = SimpleTokenizer() + + +def tokenize( + texts: Union[str, List[str]], context_length: int = 77 +) -> torch.LongTensor: + """ + Returns the tokenized representation of given input string(s) + + Parameters + ---------- + texts : Union[str, List[str]] + An input string or a list of input strings to tokenize + context_length : int + The context length to use; all CLIP models use 77 as the context length + + Returns + ------- + A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length] + """ + if isinstance(texts, str): + texts = [texts] + + sot_token = _tokenizer.encoder[""] + eot_token = _tokenizer.encoder[""] + all_tokens = [[sot_token] + _tokenizer.encode(text) + [eot_token] for text in texts] + result = torch.zeros(len(all_tokens), context_length, dtype=torch.long) + + for i, tokens in enumerate(all_tokens): + if len(tokens) > context_length: + tokens = tokens[:context_length] # Truncate + result[i, : len(tokens)] = 
def image_transform(
    image_size: int,
    is_train: bool,
    mean=(0.48145466, 0.4578275, 0.40821073),
    std=(0.26862954, 0.26130258, 0.27577711),
):
    """Compose the image preprocessing pipeline for training or evaluation.

    Training uses a mild random-resized crop; evaluation resizes and then
    centre-crops.  Both variants end with RGB conversion, tensor conversion
    and channel normalisation using *mean* / *std*.
    """
    normalize = Normalize(mean=mean, std=std)

    if is_train:
        spatial_steps = [
            RandomResizedCrop(
                image_size,
                scale=(0.9, 1.0),
                interpolation=InterpolationMode.BICUBIC,
            ),
        ]
    else:
        spatial_steps = [
            Resize(image_size, interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
        ]

    shared_tail = [_convert_to_rgb, ToTensor(), normalize]
    return Compose(spatial_steps + shared_tail)
+dataset_split = { + "audiocaps": ["train", "valid", "test"], + "audioset": ["balanced_train", "unbalanced_train", "eval"], + "BBCSoundEffects": ["train", "test"], + "Clotho": ["train", "test", "valid"], + "free_to_use_sounds": ["train", "test"], + "paramount_motion": ["train", "test"], + "sonniss_game_effects": ["train", "test"], + "wesoundeffects": ["train", "test"], + "MACS": ["train", "test"], + "freesound": ["train", "test"], + "FSD50K": ["train", "test", "valid"], + "fsd50k_class_label": ["train", "test", "valid"], + "esc50": ["train", "test"], + "audiostock": ["train", "test"], + "freesound_no_overlap_noesc50": ["train", "test"], + "epidemic_sound_effects": ["train", "test"], + "VGGSound": ["train", "test"], + "urbansound8k_class_label": ["train", "test"], + "audioset_t5": ["balanced_train", "unbalanced_train", "eval"], + "epidemic_sound_effects_t5": ["train", "test"], + "WavText5K": ["train", "test"], + "esc50_no_overlap": ["train", "test"], + "usd8k_no_overlap": ["train", "test"], + "fsd50k_200_class_label": ["train", "test", "valid"], +} + + +def freeze_batch_norm_2d(module, module_match={}, name=""): + """ + Converts all `BatchNorm2d` and `SyncBatchNorm` layers of provided module into `FrozenBatchNorm2d`. If `module` is + itself an instance of either `BatchNorm2d` or `SyncBatchNorm`, it is converted into `FrozenBatchNorm2d` and + returned. Otherwise, the module is walked recursively and submodules are converted in place. + + Args: + module (torch.nn.Module): Any PyTorch module. 
def get_tar_path_from_dataset_name(
    dataset_names, dataset_types, islocal, dataset_path, proportion=1, full_dataset=None
):
    """Collect the tar shard locations for the requested datasets and splits.

    For every dataset/split pair a ``sizes.json`` index is read; its keys
    are the shard file names.  With ``islocal`` the index is looked up under
    ``dataset_path`` first and under ``./json_files`` as a fallback, and the
    returned entries are local paths.  Otherwise only ``./json_files`` is
    consulted and the entries are ``pipe:aws s3 ... cp`` commands.  Pairs
    without an index file are skipped silently.

    Args:
        dataset_names: Iterable of dataset names.
        dataset_types: Splits to use for datasets not in ``full_dataset``.
        islocal: Whether shards are read from the local filesystem.
        dataset_path: Root directory of the local datasets.
        proportion: If not 1, the fraction of shards randomly sampled from
            each dataset/split pair.
        full_dataset: Optional collection of dataset names for which every
            split listed in ``dataset_split`` is used instead.

    Returns:
        Flat list of shard locations.
    """
    tar_paths = []
    for name in dataset_names:
        if full_dataset is not None and name in full_dataset:
            split_names = dataset_split[name]
        else:
            split_names = dataset_types

        for split in split_names:
            fallback_index = f"./json_files/{name}/{split}/sizes.json"
            if islocal:
                index_path = f"{dataset_path}/{name}/{split}/sizes.json"
                if not os.path.exists(index_path):
                    index_path = fallback_index
            else:
                index_path = fallback_index
            if not os.path.exists(index_path):
                continue

            # Close the index file deterministically instead of leaking the
            # handle returned by a bare open().
            with open(index_path, "r") as index_file:
                sizes = json.load(index_file)

            if islocal:
                shards = [f"{dataset_path}/{name}/{split}/{k}" for k in sizes]
            else:
                shards = [
                    f"pipe:aws s3 --cli-connect-timeout 0 cp s3://s-laion-audio/webdataset_tar/{name}/{split}/{k} -"
                    for k in sizes
                ]

            if proportion != 1:
                shards = random.sample(shards, int(proportion * len(shards)))
            # extend() keeps the flattening linear; sum(lists, []) is quadratic.
            tar_paths.extend(shards)
    return tar_paths
def pad_framewise_output(framewise_output, frames_num):
    """Pad framewise_output along time up to ``frames_num`` frames.

    The padding repeats the last existing frame.

    Args:
        framewise_output: (batch_size, frames_num, classes_num)
        frames_num: int, total number of frames after padding
    Returns:
        output: (batch_size, frames_num, classes_num)
    """
    # Copies of the final frame, one per missing time step.
    pad = framewise_output[:, -1:, :].repeat(
        1, frames_num - framewise_output.shape[1], 1
    )

    output = torch.cat((framewise_output, pad), dim=1)

    # The padded tensor was previously built but never returned, so every
    # caller received None.
    return output
+ .split(" ")[1] + .replace(":", ""): float(line.split(" ")[0].split(" ")[-1]) + } + for i in range(1, len(line.split(" "))): + d = save_to_dict(line.split(" ")[i], d) + val_data[num_epoch] = d + elif "Train Epoch" in lines[i]: + num_epoch = int(lines[i].split("Train Epoch: ")[1][0]) + loss = float(lines[i].split("Loss: ")[-1].split(" (")[0]) + train_losses.append(loss) + train_losses_epoch.append(num_epoch) + for i in range(len(train_losses)): + train_data[i] = { + "num_epoch": train_losses_epoch[i], + "train_loss": train_losses[i], + } + return train_data, val_data + + +def save_p(obj, filename): + import pickle + + try: + from deepdiff import DeepDiff + except: + os.system("pip install deepdiff") + from deepdiff import DeepDiff + with open(filename, "wb") as file: + pickle.dump(obj, file, protocol=pickle.HIGHEST_PROTOCOL) # highest protocol + with open(filename, "rb") as file: + z = pickle.load(file) + assert ( + DeepDiff(obj, z, ignore_string_case=True) == {} + ), "there is something wrong with the saving process" + return + + +def load_p(filename): + import pickle + + with open(filename, "rb") as file: + z = pickle.load(file) + return z + + +def save_json(data, name="data.json"): + import json + + with open(name, "w") as fp: + json.dump(data, fp) + return + + +def load_json(name): + import json + + with open(name, "r") as fp: + data = json.load(fp) + return data + + +from multiprocessing import Process, Manager +from multiprocessing import Process, Value, Array +from ctypes import c_wchar + + +def load_class_label(path): + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + out = None + if path is not None: + if pathlib.Path(path).suffix in [".pkl", ".pickle"]: + out = load_p(path) + elif pathlib.Path(path).suffix in [".json", ".txt"]: + out = load_json(path) + elif 
pathlib.Path(path).suffix in [".npy", ".npz"]: + out = np.load(path) + elif pathlib.Path(path).suffix in [".csv"]: + import pandas as pd + + out = pd.read_csv(path) + return out + # if out is None: + # return None + # else: + # key = Array(c_wchar, '\n'.join(list(out.keys())), lock=False) + # val = Array('i', out.values(), lock=False) + # return (key, val) + + +from torch import optim + + +def get_optimizer(params, lr, betas, eps, momentum, optimizer_name): + if optimizer_name.lower() == "adamw": + optimizer = optim.AdamW(params, lr=lr, betas=betas, eps=eps) + elif optimizer_name.lower() == "sgd": + optimizer = optim.SGD(params, lr=lr, momentum=momentum) + elif optimizer_name.lower() == "adam": + optimizer = optim.Adam(params, lr=lr, betas=betas, eps=eps) + else: + raise ValueError("optimizer name is not correct") + return optimizer diff --git a/audioldm/clap/open_clip/version.py b/audioldm/clap/open_clip/version.py new file mode 100644 index 0000000000000000000000000000000000000000..3ced3581bb601ae91b1e1da4b8f4f520855a065e --- /dev/null +++ b/audioldm/clap/open_clip/version.py @@ -0,0 +1 @@ +__version__ = "0.2.1" diff --git a/audioldm/clap/training/__init__.py b/audioldm/clap/training/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/clap/training/audioset_textmap.npy b/audioldm/clap/training/audioset_textmap.npy new file mode 100644 index 0000000000000000000000000000000000000000..3da4c92d3819aaec11e5f576464a9973a6df811b --- /dev/null +++ b/audioldm/clap/training/audioset_textmap.npy @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bada103070d92f9eadd33e1b4f45ec8583f59080ef218c966b43294bd4c86d5b +size 84448 diff --git a/audioldm/clap/training/data.py b/audioldm/clap/training/data.py new file mode 100644 index 0000000000000000000000000000000000000000..a005fee2f51e577446839b8cffd117d9ae93abc9 --- /dev/null +++ b/audioldm/clap/training/data.py @@ -0,0 
+1,981 @@ +import ast +import json +import logging +import math +import os +import random + +# import h5py +from dataclasses import dataclass +from audioldm.clap.training.params import parse_args + +# import braceexpand +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.datasets as datasets +import torchvision.transforms + +# import webdataset as wds +from PIL import Image +from torch.utils.data import Dataset, DataLoader, SubsetRandomSampler +from torch.utils.data.distributed import DistributedSampler +from functools import partial +import soundfile as sf +import io +from pathlib import Path + +# import wget + +from audioldm.clap.open_clip.utils import ( + get_tar_path_from_dataset_name, + dataset_split, +) +from audioldm.clap.open_clip.utils import load_p, load_class_label +import copy + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +try: + import torchaudio +except ImportError: + torchaudio = None + +from audioldm.clap.open_clip import tokenize + + +def tokenizer(text): + return tokenize(text).squeeze(0) + + +from transformers import RobertaTokenizer + +tokenize = RobertaTokenizer.from_pretrained("roberta-base") + + +def tokenizer(text): + result = tokenize( + text, + padding="max_length", + truncation=True, + max_length=77, + return_tensors="pt", + ) + return {k: v.squeeze(0) for k, v in result.items()} + + +# initizlied the audioset map +_AUDIOSET_MAP_PATH = os.path.join(Path(__file__).parent, "audioset_textmap.npy") +_AUDIOSET_MAP = np.load(_AUDIOSET_MAP_PATH, allow_pickle=True) + + +def int16_to_float32(x): + return (x / 32767.0).astype(np.float32) + + +def float32_to_int16(x): + x = np.clip(x, a_min=-1.0, a_max=1.0) + return (x * 32767.0).astype(np.int16) + + +# For Toy Dataset +# class ToyDataset(Dataset): +# def __init__(self, index_path, ipc, config, eval_mode=False): +# """Toy Dataset for testing the audioset input with text labels +# Parameters +# 
---------- +# index_path: str +# the link to the h5 file of each audio +# idc: str +# the link to the npy file, the number of samples in each class +# config: dict +# the audio cfg file +# eval_model (bool): to indicate if the dataset is a testing dataset +# """ +# self.audio_cfg = config["audio_cfg"] +# self.text_cfg = config["text_cfg"] +# self.fp = h5py.File(index_path, "r") +# self.ipc = np.load(ipc, allow_pickle=True) +# self.total_size = len(self.fp["audio_name"]) +# self.classes_num = self.audio_cfg["class_num"] +# self.eval_mode = eval_mode + +# if not eval_mode: +# self.generate_queue() +# else: +# self.queue = [] +# for i in range(self.total_size): +# target = self.fp["target"][i] +# if np.sum(target) > 0: +# self.queue.append(i) +# self.total_size = len(self.queue) +# logging.info("total dataset size: %d" % (self.total_size)) +# logging.info("class num: %d" % (self.classes_num)) + +# def time_shifting(self, x): +# frame_num = len(x) +# shift_len = random.randint(0, frame_num - 1) +# new_sample = np.concatenate([x[shift_len:], x[:shift_len]], axis=0) +# return new_sample + +# def generate_queue(self): +# self.queue = [] +# while len(self.queue) < self.total_size: +# class_set = [*range(self.classes_num)] +# random.shuffle(class_set) +# self.queue += [ +# self.ipc[d][random.randint(0, len(self.ipc[d]) - 1)] for d in class_set +# ] +# self.queue = self.queue[: self.total_size] + +# logging.info("queue regenerated:%s" % (self.queue[-5:])) + +# def crop_wav(self, x): +# crop_size = self.audio_cfg["crop_size"] +# crop_pos = random.randint(0, len(x) - crop_size - 1) +# return x[crop_pos : crop_pos + crop_size] + +# def prompt_text(self, target): +# events = _AUDIOSET_MAP[np.where(target > 0)] +# event_text = "The sounds of " + ", ".join(events[:-1]) + " and " + events[-1] +# text = tokenize(event_text)[0] +# return text + +# def __getitem__(self, index): +# """Load waveform, text, and target of an audio clip + +# Parameters +# ---------- +# index: int +# the 
index number +# Return +# ------ +# output: dict { +# "hdf5_path": str, +# "index_in_hdf5": int, +# "audio_name": str, +# "waveform": list (audio_length,), +# "target": list (class_num, ), +# "text": torch.tensor (context_length,) +# } +# the output dictionary +# """ +# s_index = self.queue[index] + +# audio_name = self.fp["audio_name"][s_index].decode() +# # Hardcode here CHANGE +# hdf5_path = ( +# self.fp["hdf5_path"][s_index] +# .decode() +# .replace( +# "../workspace", +# "/home/la/kechen/Research/ke_zsasp/workspace", +# ) +# ) +# r_idx = self.fp["index_in_hdf5"][s_index] +# target = self.fp["target"][s_index].astype(np.float32) +# text = self.prompt_text(target) +# with h5py.File(hdf5_path, "r") as f: +# waveform = int16_to_float32(f["waveform"][r_idx])[ +# : self.audio_cfg["clip_samples"] +# ] +# assert ( +# len(waveform) == self.audio_cfg["clip_samples"] +# ), "The sample length is not match" +# # Time shift +# # if (self.config.enable_time_shift) and (not self.eval_mode): +# # waveform = self.time_shifting(waveform) +# # # Label Enhance +# # if (self.config.crop_size is not None) and (not self.eval_mode): +# # waveform = self.crop_wav(waveform) +# # # the label enhance rate is fixed 0.5 +# # if (self.config.enable_label_enhance) and (not self.eval_mode) and random.random() < 0.5: +# # kidx = np.where(target)[0] +# # for k in kidx: +# # for add_key in self.class_map[k][1]: +# # target[add_key] = 1.0 +# # if len(self.class_map[k][2]) > 0: +# # add_key = random.choice(self.class_map[k][2]) +# # target[add_key] = 1.0 + +# # missing the text input +# mel_spec = get_mel(torch.from_numpy(waveform), self.audio_cfg)[None, :, :] +# mel_spec = ( +# torch.cat( +# [mel_spec, mel_spec.clone(), mel_spec.clone(), mel_spec.clone()], dim=0 +# ) +# .cpu() +# .numpy() +# ) +# longer = random.choice([True, False]) +# if longer == False: +# mel_spec[1:, :, :] = 0.0 +# data_dict = { +# "hdf5_path": hdf5_path, +# "index_in_hdf5": r_idx, +# "audio_name": audio_name, +# "waveform": 
waveform, +# "class_label": target, +# "text": text, +# "longer": longer, +# "mel_fusion": mel_spec, +# } +# return data_dict + +# def __len__(self): +# return self.total_size + + +class CsvDataset(Dataset): + def __init__(self, input_filename, transforms, img_key, caption_key, sep="\t"): + logging.debug(f"Loading csv data from {input_filename}.") + df = pd.read_csv(input_filename, sep=sep) + + self.images = df[img_key].tolist() + self.captions = df[caption_key].tolist() + self.transforms = transforms + logging.debug("Done loading data.") + + def __len__(self): + return len(self.captions) + + def __getitem__(self, idx): + images = self.transforms(Image.open(str(self.images[idx]))) + texts = tokenize([str(self.captions[idx])])[0] + return images, texts + + +@dataclass +class DataInfo: + dataloader: DataLoader + sampler: DistributedSampler + + +def preprocess_txt(text): + return tokenize([str(text)])[0] + + +def get_dataset_size(shards, sizefilepath_=None, is_local=True): + if isinstance(shards, list): + size_list = [] + for s in shards: + size_list.append( + get_dataset_size(s, sizefilepath_=sizefilepath_, is_local=is_local)[0] + ) + else: + if not is_local: + for n in dataset_split.keys(): + if n in shards.split("/"): + break + for s in dataset_split[n]: + if s in shards.split("/"): + break + sizefilepath_ = f"./json_files/{n}/{s}/sizes.json" + shards_list = list(braceexpand.braceexpand(shards)) + dir_path = os.path.dirname(shards) + if sizefilepath_ is not None: + sizes = json.load(open(sizefilepath_, "r")) + total_size = sum( + [ + int(sizes[os.path.basename(shard.replace(".tar -", ".tar"))]) + for shard in shards_list + ] + ) + else: + sizes_filename = os.path.join(dir_path, "sizes.json") + len_filename = os.path.join(dir_path, "__len__") + if os.path.exists(sizes_filename): + sizes = json.load(open(sizes_filename, "r")) + total_size = sum( + [int(sizes[os.path.basename(shard)]) for shard in shards_list] + ) + elif os.path.exists(len_filename): + # FIXME this 
used to be eval(open(...)) but that seemed rather unsafe + total_size = ast.literal_eval(open(len_filename, "r").read()) + else: + raise Exception( + "Cannot find sizes file for dataset. Please specify the path to the file." + ) + # total_size = None # num samples undefined + # some common dataset sizes (at time of authors last download) + # cc3m-train: 2905954 + # cc12m: 10968539 + # LAION-400m: 407332084 + num_shards = len(shards_list) + if isinstance(shards, list): + return sum(size_list), len(shards) + else: + return total_size, num_shards + + +def get_imagenet(args, preprocess_fns, split): + assert split in ["train", "val", "v2"] + is_train = split == "train" + preprocess_train, preprocess_val = preprocess_fns + + if split == "v2": + from imagenetv2_pytorch import ImageNetV2Dataset + + dataset = ImageNetV2Dataset(location=args.imagenet_v2, transform=preprocess_val) + else: + if is_train: + data_path = args.imagenet_train + preprocess_fn = preprocess_train + else: + data_path = args.imagenet_val + preprocess_fn = preprocess_val + assert data_path + + dataset = datasets.ImageFolder(data_path, transform=preprocess_fn) + + if is_train: + idxs = np.zeros(len(dataset.targets)) + target_array = np.array(dataset.targets) + k = 50 + for c in range(1000): + m = target_array == c + n = len(idxs[m]) + arr = np.zeros(n) + arr[:k] = 1 + np.random.shuffle(arr) + idxs[m] = arr + + idxs = idxs.astype("int") + sampler = SubsetRandomSampler(np.where(idxs)[0]) + else: + sampler = None + + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=args.batch_size, + num_workers=args.workers, + sampler=sampler, + ) + + return DataInfo(dataloader, sampler) + + +def count_samples(dataloader): + os.environ["WDS_EPOCH"] = "0" + n_elements, n_batches = 0, 0 + for images, texts in dataloader: + n_batches += 1 + n_elements += len(images) + assert len(images) == len(texts) + return n_elements, n_batches + + +def filter_no_caption(sample): + return "txt" in sample + + +def 
log_and_continue(exn): + """Call in an exception handler to ignore any exception, isssue a warning, and continue.""" + logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.") + return True + + +_SHARD_SHUFFLE_SIZE = 2000 +_SHARD_SHUFFLE_INITIAL = 500 +_SAMPLE_SHUFFLE_SIZE = 5000 +_SAMPLE_SHUFFLE_INITIAL = 1000 + + +def sample_prop(sizefile, inputs, proportion, is_local=True): + """ + Sample a proportion of the data. + """ + file_path_dict = { + os.path.split(inputs[i])[1]: os.path.split(inputs[i])[0] + for i in range(len(inputs)) + } + sampled_filepath_dict = {} + sampled_size_dict = {} + if not is_local: + if os.path.exists("sizes.json"): + os.remove("sizes.json") + wget.download(sizefile, "sizes.json") + sizefile = "sizes.json" + with open(sizefile, "r", encoding="UTF-8") as f: + load_dict = json.load(f) + L = int(len(file_path_dict) * proportion) + subkeys = random.sample(file_path_dict.keys(), L) + for k in subkeys: + sampled_size_dict[k] = load_dict[k] + sampled_filepath_dict[k] = file_path_dict[k] + return ( + sum(sampled_size_dict.values()), + L, + [os.path.join(v, k) for k, v in sampled_filepath_dict.items()], + sampled_size_dict, + ) + + +def get_mel(audio_data, audio_cfg): + # mel shape: (n_mels, T) + mel = torchaudio.transforms.MelSpectrogram( + sample_rate=audio_cfg["sample_rate"], + n_fft=audio_cfg["window_size"], + win_length=audio_cfg["window_size"], + hop_length=audio_cfg["hop_size"], + center=True, + pad_mode="reflect", + power=2.0, + norm=None, + onesided=True, + n_mels=64, + f_min=audio_cfg["fmin"], + f_max=audio_cfg["fmax"], + ).to(audio_data.device) + mel = mel(audio_data) + # Align to librosa: + # librosa_melspec = librosa.feature.melspectrogram( + # waveform, + # sr=audio_cfg['sample_rate'], + # n_fft=audio_cfg['window_size'], + # hop_length=audio_cfg['hop_size'], + # win_length=audio_cfg['window_size'], + # center=True, + # pad_mode="reflect", + # power=2.0, + # n_mels=64, + # norm=None, + # htk=True, + # 
f_min=audio_cfg['fmin'], + # f_max=audio_cfg['fmax'] + # ) + # we use log mel spectrogram as input + mel = torchaudio.transforms.AmplitudeToDB(top_db=None)(mel) + return mel.T # (T, n_mels) + + +def get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg +): + """ + Calculate and add audio features to sample. + Sample: a dict containing all the data of current sample. + audio_data: a tensor of shape (T) containing audio data. + max_len: the maximum length of audio data. + data_truncating: the method of truncating data. + data_filling: the method of filling data. + audio_cfg: a dict containing audio configuration. Comes from model_cfg['audio_cfg']. + """ + with torch.no_grad(): + if len(audio_data) > max_len: + if data_truncating == "rand_trunc": + longer = torch.tensor([True]) + elif data_truncating == "fusion": + # fusion + mel = get_mel(audio_data, audio_cfg) + # split to three parts + chunk_frames = ( + max_len // audio_cfg["hop_size"] + 1 + ) # the +1 related to how the spectrogram is computed + total_frames = mel.shape[0] + if chunk_frames == total_frames: + # there is a corner case where the audio length is + # larger than max_len but smaller than max_len+hop_size. + # In this case, we just use the whole audio. 
+ mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + else: + ranges = np.array_split( + list(range(0, total_frames - chunk_frames + 1)), 3 + ) + # print('total_frames-chunk_frames:', total_frames-chunk_frames, + # 'len(audio_data):', len(audio_data), + # 'chunk_frames:', chunk_frames, + # 'total_frames:', total_frames) + if len(ranges[1]) == 0: + # if the audio is too short, we just use the first chunk + ranges[1] = [0] + if len(ranges[2]) == 0: + # if the audio is too short, we just use the first chunk + ranges[2] = [0] + # randomly choose index for each part + idx_front = np.random.choice(ranges[0]) + idx_middle = np.random.choice(ranges[1]) + idx_back = np.random.choice(ranges[2]) + # select mel + mel_chunk_front = mel[idx_front : idx_front + chunk_frames, :] + mel_chunk_middle = mel[idx_middle : idx_middle + chunk_frames, :] + mel_chunk_back = mel[idx_back : idx_back + chunk_frames, :] + + # shrink the mel + mel_shrink = torchvision.transforms.Resize(size=[chunk_frames, 64])( + mel[None] + )[0] + # logging.info(f"mel_shrink.shape: {mel_shrink.shape}") + + # stack + mel_fusion = torch.stack( + [mel_chunk_front, mel_chunk_middle, mel_chunk_back, mel_shrink], + dim=0, + ) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([True]) + else: + raise NotImplementedError( + f"data_truncating {data_truncating} not implemented" + ) + # random crop to max_len (for compatibility) + overflow = len(audio_data) - max_len + idx = np.random.randint(0, overflow + 1) + audio_data = audio_data[idx : idx + max_len] + + else: # padding if too short + if len(audio_data) < max_len: # do nothing if equal + if data_filling == "repeatpad": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat) + # audio_data = audio_data.unsqueeze(0).unsqueeze(0).unsqueeze(0) + # audio_data = F.interpolate(audio_data,size=max_len,mode="bicubic")[0,0,0] + audio_data = F.pad( + audio_data, + 
(0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "pad": + audio_data = F.pad( + audio_data, + (0, max_len - len(audio_data)), + mode="constant", + value=0, + ) + elif data_filling == "repeat": + n_repeat = int(max_len / len(audio_data)) + audio_data = audio_data.repeat(n_repeat + 1)[:max_len] + else: + raise NotImplementedError( + f"data_filling {data_filling} not implemented" + ) + if data_truncating == "fusion": + mel = get_mel(audio_data, audio_cfg) + mel_fusion = torch.stack([mel, mel, mel, mel], dim=0) + sample["mel_fusion"] = mel_fusion + longer = torch.tensor([False]) + + sample["longer"] = longer + sample["waveform"] = audio_data + + return sample + + +def preprocess( + sample, + audio_ext, + text_ext, + max_len, + audio_cfg, + class_index_dict=None, + data_filling="pad", + data_truncating="rand_trunc", + text_augment_selection=None, +): + """ + Preprocess a single sample for wdsdataloader. + """ + audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + audio_data = int16_to_float32(float32_to_int16(audio_data)) + audio_data = torch.tensor(audio_data).float() + + # TODO: (yusong) to be include in the future + # # if torchaudio not installed, use soundfile to load audio + # if torchaudio is None: + # audio_data, orig_sr = sf.read(io.BytesIO(sample[audio_ext])) + # audio_data = torch.tensor(audio_data).float() + # else: + # # https://github.com/webdataset/webdataset/blob/main/webdataset/autodecode.py + # with tempfile.TemporaryDirectory() as dirname: + # os.makedirs(dirname, exist_ok=True) + # fname = os.path.join(dirname, f"file.flac") + # with open(fname, "wb") as stream: + # stream.write(sample[audio_ext]) + # audio_data, orig_sr = torchaudio.load(fname) + # audio_data = audio_data[0, :].float() + + sample = get_audio_features( + sample, audio_data, max_len, data_truncating, data_filling, audio_cfg + ) + del sample[audio_ext] + + try: + json_dict_raw = json.loads(sample[text_ext].decode("utf-8")) + except: + 
print("sample[__url__]:", sample["__url__"]) + + # For selecting augmented text from dataset + if text_augment_selection is None or text_augment_selection == "none": + texts = json_dict_raw["text"] + elif text_augment_selection == "all": + if "text_augment_all" in json_dict_raw.keys(): + texts = json_dict_raw["text_augment_all"] + else: + texts = json_dict_raw["text"] + elif text_augment_selection == "augment_only": + if "text_augment_all" in json_dict_raw.keys(): + if json_dict_raw["text_augment_t5"] is None: + texts = json_dict_raw["text"] + else: + texts = json_dict_raw["text_augment_t5"] + else: + texts = json_dict_raw["text"] + else: + raise NotImplementedError( + f"text_augment_selection {text_augment_selection} not implemented" + ) + sample["full_text"] = texts + + if isinstance(texts, list) and isinstance(texts[0], str) and len(texts) > 1: + texts = random.choice(texts) + sample["raw_text"] = texts + sample["text"] = tokenizer(texts) # text shape: [num_token] + if class_index_dict is not None: + # https://stackoverflow.com/questions/48004243/how-to-share-large-read-only-dictionary-list-across-processes-in-multiprocessing + # https://stackoverflow.com/questions/45693949/storing-strings-in-a-multiprocessing-sharedctypes-array + # key, val = class_index_dict + # key = key[:].split('\n') + # _dict = {k: v for k, v in zip(key, val)} + sample["class_label"] = np.zeros(len(class_index_dict.keys())) + for x in json_dict_raw["tag"]: + sample["class_label"][class_index_dict[x]] = 1 + sample["class_label"] = torch.tensor(sample["class_label"]).float() + del sample[text_ext] + sample["audio_name"] = sample["__key__"].split("/")[-1] + "." + audio_ext + sample["text_name"] = sample["__key__"].split("/")[-1] + "." + text_ext + sample["audio_orig_sr"] = orig_sr + return sample + + +def collate_fn(batch): + """ + Collate function for wdsdataloader. + batch: a list of dict, each dict is a sample + """ + # concatenate values in each dictionary. if it is a tensor, concatenate. 
if it is a list, extend. + batch_dict = {} + for k in batch[0].keys(): + if isinstance(batch[0][k], dict): # dealwith bert tokenizer output + batch_dict[k] = {} + for kk in batch[0][k].keys(): + tmp = [] + for i in range(len(batch)): + tmp.append(batch[i][k][kk]) + batch_dict[k][kk] = torch.vstack(tmp) + elif isinstance(batch[0][k], torch.Tensor): + batch_dict[k] = torch.stack([sample[k] for sample in batch]) + elif isinstance(batch[0][k], np.ndarray): + batch_dict[k] = torch.tensor(np.stack([sample[k] for sample in batch])) + else: + batch_dict[k] = [sample[k] for sample in batch] + return batch_dict + + +def get_wds_dataset( + args, + model_cfg, + is_train, + audio_ext="flac", + text_ext="json", + max_len=480000, + proportion=1.0, + sizefilepath_=None, + is_local=None, +): + """ + Get a dataset for wdsdataloader. + """ + if is_local is None and (not args.remotedata is None): + is_local = not args.remotedata + + input_shards = args.train_data if is_train else args.val_data + assert input_shards is not None + + if not sizefilepath_ is None: + sizefilepath = sizefilepath_ + else: + sizefilepath = os.path.join(os.path.dirname(input_shards[0]), "sizes.json") + + if proportion != 1.0: + num_samples, num_shards, input_shards, _ = sample_prop( + sizefilepath, input_shards, proportion, is_local=is_local + ) + else: + num_samples, num_shards = get_dataset_size( + input_shards, sizefilepath_=sizefilepath_, is_local=is_local + ) + + if not num_samples: + if is_train: + num_samples = args.train_num_samples + if not num_samples: + raise RuntimeError( + "Currently, number of dataset samples must be specified for training dataset. " + "Please specify via `--train-num-samples` if no dataset length info present." 
+ ) + else: + num_samples = ( + args.val_num_samples or 0 + ) # eval will just exhaust the iterator if not specified + + pipeline = [wds.SimpleShardList(input_shards)] + # at this point we have an iterator over all the shards + # TODO: (yusong): add a if statement of distributed. If not, we don't need to split_by_node + if is_train or args.parallel_eval: + pipeline.extend( + [ + wds.detshuffle( + bufsize=_SHARD_SHUFFLE_SIZE, + initial=_SHARD_SHUFFLE_INITIAL, + seed=args.seed, + ), + wds.split_by_node, + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker at each node + wds.tarfile_to_samples(handler=log_and_continue), + wds.shuffle( + bufsize=_SAMPLE_SHUFFLE_SIZE, + initial=_SAMPLE_SHUFFLE_INITIAL, + rng=random.Random(args.seed), + ), + # wds.repeatedly, # FIXME determine if this is beneficial + ] + ) + else: + pipeline.extend( + [ + wds.split_by_worker, + # at this point, we have an iterator over the shards assigned to each worker + wds.tarfile_to_samples(handler=log_and_continue), + ] + ) + pipeline.append( + wds.map( + partial( + preprocess, + audio_ext=audio_ext, + text_ext=text_ext, + max_len=max_len, + audio_cfg=model_cfg["audio_cfg"], + class_index_dict=copy.deepcopy(args.class_index_dict), + data_filling=args.data_filling, + data_truncating=args.data_truncating, + text_augment_selection=args.text_augment_selection, + ) + ), + ) + + pipeline.append( + wds.batched( + args.batch_size, + partial=not (is_train or args.parallel_eval), + collation_fn=collate_fn, + ) + ) + + dataset = wds.DataPipeline(*pipeline) + if is_train or args.parallel_eval: + # (yusong): Currently parallel evaluation will be not precise as we are repeat the last few samples. + # (yusong): See comments below. 
+ # roll over and repeat a few samples to get same number of full batches on each node + global_batch_size = args.batch_size * args.world_size + num_batches = math.ceil(num_samples / global_batch_size) + num_workers = max(1, args.workers) + num_worker_batches = math.ceil( + num_batches / num_workers + ) # per dataloader worker + num_batches = num_worker_batches * num_workers + num_samples = num_batches * global_batch_size + dataset = dataset.with_epoch( + num_worker_batches + ) # each worker is iterating over this + else: + # last batches are partial, eval is done on single (master) node + num_batches = math.ceil(num_samples / args.batch_size) + + kwargs = {} + if args.horovod: # multi-node training on summit + kwargs["multiprocessing_context"] = "forkserver" + + dataloader = wds.WebLoader( + dataset, batch_size=None, shuffle=False, num_workers=args.workers, **kwargs + ) + + # FIXME not clear which approach is better, with_epoch before vs after dataloader? + # hoping to resolve via https://github.com/webdataset/webdataset/issues/169 + # if is_train: + # # roll over and repeat a few samples to get same number of full batches on each node + # global_batch_size = args.batch_size * args.world_size + # num_batches = math.ceil(num_samples / global_batch_size) + # num_workers = max(1, args.workers) + # num_batches = math.ceil(num_batches / num_workers) * num_workers + # num_samples = num_batches * global_batch_size + # dataloader = dataloader.with_epoch(num_batches) + # else: + # # last batches are partial, eval is done on single (master) node + # num_batches = math.ceil(num_samples / args.batch_size) + + # add meta-data to dataloader instance for convenience + dataloader.num_batches = num_batches + dataloader.num_samples = num_samples + + return DataInfo(dataloader, None) + + +def wds_batch_list2dict( + batch, + keys=[ + "__url__", + "__key__", + "waveform", + "text", + "raw_text", + "audio_name", + "text_name", + "audio_orig_sr", + ], +): + """ + Return a dictionary of 
the batch, with keys as the names of the fields. + """ + assert len(keys) == len( + batch + ), "batch must have same number of keys as keys argument" + return {keys[i]: batch[i] for i in range(len(batch))} + + +def get_csv_dataset(args, preprocess_fn, is_train): + input_filename = args.train_data if is_train else args.val_data + assert input_filename + dataset = CsvDataset( + input_filename, + preprocess_fn, + img_key=args.csv_img_key, + caption_key=args.csv_caption_key, + sep=args.csv_separator, + ) + num_samples = len(dataset) + sampler = DistributedSampler(dataset) if args.distributed and is_train else None + shuffle = is_train and sampler is None + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=shuffle, + num_workers=args.workers, + pin_memory=True, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_toy_dataset(args, model_cfg, is_train): + index_path = args.train_data if is_train else args.val_data + ipc_path = args.train_ipc if is_train else args.val_ipc + assert index_path and ipc_path + eval_mode = not is_train + dataset = ToyDataset(index_path, ipc_path, model_cfg, eval_mode=eval_mode) + + num_samples = len(dataset) + sampler = ( + DistributedSampler(dataset, shuffle=False) + if args.distributed and is_train + else None + ) + + dataloader = DataLoader( + dataset, + batch_size=args.batch_size, + shuffle=False, + num_workers=args.workers, + sampler=sampler, + drop_last=is_train, + ) + dataloader.num_samples = num_samples + dataloader.num_batches = len(dataloader) + + return DataInfo(dataloader, sampler) + + +def get_dataset_fn(data_path, dataset_type): + if dataset_type == "webdataset": + return get_wds_dataset + elif dataset_type == "csv": + return get_csv_dataset + elif dataset_type == "auto": + ext = data_path.split(".")[-1] + if ext in ["csv", "tsv"]: + return get_csv_dataset + elif ext in 
["tar"]: + return get_wds_dataset + else: + raise ValueError( + f"Tried to figure out dataset type, but failed for extention {ext}." + ) + elif dataset_type == "toy": + return get_toy_dataset + else: + raise ValueError(f"Unsupported dataset type: {dataset_type}") + + +def get_data(args, model_cfg): + data = {} + + args.class_index_dict = load_class_label(args.class_label_path) + + if args.datasetinfos is None: + args.datasetinfos = ["train", "unbalanced_train", "balanced_train"] + if args.dataset_type == "webdataset": + args.train_data = get_tar_path_from_dataset_name( + args.datasetnames, + args.datasetinfos, + islocal=not args.remotedata, + proportion=args.dataset_proportion, + dataset_path=args.datasetpath, + full_dataset=args.full_train_dataset, + ) + + if args.full_train_dataset is None: + args.full_train_dataset = [] + if args.exclude_eval_dataset is None: + args.exclude_eval_dataset = [] + excluded_eval_datasets = args.full_train_dataset + args.exclude_eval_dataset + + val_dataset_names = ( + [n for n in args.datasetnames if n not in excluded_eval_datasets] + if excluded_eval_datasets + else args.datasetnames + ) + args.val_dataset_names = val_dataset_names + args.val_data = get_tar_path_from_dataset_name( + val_dataset_names, + ["valid", "test", "eval"], + islocal=not args.remotedata, + proportion=1, + dataset_path=args.datasetpath, + full_dataset=None, + ) + + if args.train_data: + data["train"] = get_dataset_fn(args.train_data, args.dataset_type)( + args, model_cfg, is_train=True + ) + + if args.val_data: + data["val"] = get_dataset_fn(args.val_data, args.dataset_type)( + args, model_cfg, is_train=False + ) + + return data diff --git a/audioldm/clap/training/distributed.py b/audioldm/clap/training/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa61f76c5cc3ab9f6a9643042afa8e1f2e1cb7f --- /dev/null +++ b/audioldm/clap/training/distributed.py @@ -0,0 +1,150 @@ +import os + +import torch +import socket + +try: + import 
def init_distributed_device(args):
    """Initialise (optionally distributed) training and choose the device.

    Distributed training = training on more than one GPU. Works in both
    single and multi-node scenarios.

    The function writes ``distributed``, ``world_size``, ``rank``,
    ``local_rank`` and ``device`` onto ``args``. It reads ``args.horovod``,
    ``args.dist_backend``, ``args.dist_url`` and ``args.no_set_device_rank``.

    Args:
        args: Mutable namespace holding the run configuration.

    Returns:
        The selected ``torch.device``.
    """

    def _adopt_ompi_layout():
        # Open MPI exposes the process layout through these variables.
        size = int(os.environ["OMPI_COMM_WORLD_SIZE"])
        rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
        local = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
        args.local_rank, args.rank, args.world_size = local, rank, size

    def _export_torch_env():
        # Mirror the layout into the variable names torch.distributed reads.
        for name, value in (
            ("LOCAL_RANK", args.local_rank),
            ("RANK", args.rank),
            ("WORLD_SIZE", args.world_size),
        ):
            os.environ[name] = str(value)

    def _init_group_with_known_layout():
        torch.distributed.init_process_group(
            backend=args.dist_backend,
            init_method=args.dist_url,
            world_size=args.world_size,
            rank=args.rank,
        )

    def _announce():
        print(
            f"Distributed training: local_rank={args.local_rank}, "
            f"rank={args.rank}, world_size={args.world_size}, "
            f"hostname={socket.gethostname()}, pid={os.getpid()}"
        )

    # Single-process defaults; overwritten below when a launcher is detected.
    args.distributed = False
    args.world_size = 1
    args.rank = 0  # global rank
    args.local_rank = 0

    if args.horovod:
        assert hvd is not None, "Horovod is not installed"
        hvd.init()
        _adopt_ompi_layout()
        args.distributed = True
        _export_torch_env()
        _announce()
    elif is_using_distributed():
        if "SLURM_PROCID" in os.environ:
            # DDP via SLURM
            args.local_rank, args.rank, args.world_size = world_info_from_env()
            # SLURM var -> torch.distributed vars in case needed
            _export_torch_env()
            _init_group_with_known_layout()
        elif "OMPI_COMM_WORLD_SIZE" in os.environ:  # using Summit cluster
            _adopt_ompi_layout()
            _init_group_with_known_layout()
        else:
            # DDP via torchrun, torch.distributed.launch
            args.local_rank, _, _ = world_info_from_env()
            torch.distributed.init_process_group(
                backend=args.dist_backend, init_method=args.dist_url
            )
            args.world_size = torch.distributed.get_world_size()
            args.rank = torch.distributed.get_rank()
        args.distributed = True
        _announce()

    if not torch.cuda.is_available():
        device = "cpu"
    else:
        bind_to_local_rank = args.distributed and not args.no_set_device_rank
        device = "cuda:%d" % args.local_rank if bind_to_local_rank else "cuda:0"
        torch.cuda.set_device(device)

    args.device = device
    return torch.device(device)
+ +imagenet_classnames = [ + "tench", + "goldfish", + "great white shark", + "tiger shark", + "hammerhead shark", + "electric ray", + "stingray", + "rooster", + "hen", + "ostrich", + "brambling", + "goldfinch", + "house finch", + "junco", + "indigo bunting", + "American robin", + "bulbul", + "jay", + "magpie", + "chickadee", + "American dipper", + "kite (bird of prey)", + "bald eagle", + "vulture", + "great grey owl", + "fire salamander", + "smooth newt", + "newt", + "spotted salamander", + "axolotl", + "American bullfrog", + "tree frog", + "tailed frog", + "loggerhead sea turtle", + "leatherback sea turtle", + "mud turtle", + "terrapin", + "box turtle", + "banded gecko", + "green iguana", + "Carolina anole", + "desert grassland whiptail lizard", + "agama", + "frilled-necked lizard", + "alligator lizard", + "Gila monster", + "European green lizard", + "chameleon", + "Komodo dragon", + "Nile crocodile", + "American alligator", + "triceratops", + "worm snake", + "ring-necked snake", + "eastern hog-nosed snake", + "smooth green snake", + "kingsnake", + "garter snake", + "water snake", + "vine snake", + "night snake", + "boa constrictor", + "African rock python", + "Indian cobra", + "green mamba", + "sea snake", + "Saharan horned viper", + "eastern diamondback rattlesnake", + "sidewinder rattlesnake", + "trilobite", + "harvestman", + "scorpion", + "yellow garden spider", + "barn spider", + "European garden spider", + "southern black widow", + "tarantula", + "wolf spider", + "tick", + "centipede", + "black grouse", + "ptarmigan", + "ruffed grouse", + "prairie grouse", + "peafowl", + "quail", + "partridge", + "african grey parrot", + "macaw", + "sulphur-crested cockatoo", + "lorikeet", + "coucal", + "bee eater", + "hornbill", + "hummingbird", + "jacamar", + "toucan", + "duck", + "red-breasted merganser", + "goose", + "black swan", + "tusker", + "echidna", + "platypus", + "wallaby", + "koala", + "wombat", + "jellyfish", + "sea anemone", + "brain coral", + "flatworm", + 
"nematode", + "conch", + "snail", + "slug", + "sea slug", + "chiton", + "chambered nautilus", + "Dungeness crab", + "rock crab", + "fiddler crab", + "red king crab", + "American lobster", + "spiny lobster", + "crayfish", + "hermit crab", + "isopod", + "white stork", + "black stork", + "spoonbill", + "flamingo", + "little blue heron", + "great egret", + "bittern bird", + "crane bird", + "limpkin", + "common gallinule", + "American coot", + "bustard", + "ruddy turnstone", + "dunlin", + "common redshank", + "dowitcher", + "oystercatcher", + "pelican", + "king penguin", + "albatross", + "grey whale", + "killer whale", + "dugong", + "sea lion", + "Chihuahua", + "Japanese Chin", + "Maltese", + "Pekingese", + "Shih Tzu", + "King Charles Spaniel", + "Papillon", + "toy terrier", + "Rhodesian Ridgeback", + "Afghan Hound", + "Basset Hound", + "Beagle", + "Bloodhound", + "Bluetick Coonhound", + "Black and Tan Coonhound", + "Treeing Walker Coonhound", + "English foxhound", + "Redbone Coonhound", + "borzoi", + "Irish Wolfhound", + "Italian Greyhound", + "Whippet", + "Ibizan Hound", + "Norwegian Elkhound", + "Otterhound", + "Saluki", + "Scottish Deerhound", + "Weimaraner", + "Staffordshire Bull Terrier", + "American Staffordshire Terrier", + "Bedlington Terrier", + "Border Terrier", + "Kerry Blue Terrier", + "Irish Terrier", + "Norfolk Terrier", + "Norwich Terrier", + "Yorkshire Terrier", + "Wire Fox Terrier", + "Lakeland Terrier", + "Sealyham Terrier", + "Airedale Terrier", + "Cairn Terrier", + "Australian Terrier", + "Dandie Dinmont Terrier", + "Boston Terrier", + "Miniature Schnauzer", + "Giant Schnauzer", + "Standard Schnauzer", + "Scottish Terrier", + "Tibetan Terrier", + "Australian Silky Terrier", + "Soft-coated Wheaten Terrier", + "West Highland White Terrier", + "Lhasa Apso", + "Flat-Coated Retriever", + "Curly-coated Retriever", + "Golden Retriever", + "Labrador Retriever", + "Chesapeake Bay Retriever", + "German Shorthaired Pointer", + "Vizsla", + "English Setter", + 
"Irish Setter", + "Gordon Setter", + "Brittany dog", + "Clumber Spaniel", + "English Springer Spaniel", + "Welsh Springer Spaniel", + "Cocker Spaniel", + "Sussex Spaniel", + "Irish Water Spaniel", + "Kuvasz", + "Schipperke", + "Groenendael dog", + "Malinois", + "Briard", + "Australian Kelpie", + "Komondor", + "Old English Sheepdog", + "Shetland Sheepdog", + "collie", + "Border Collie", + "Bouvier des Flandres dog", + "Rottweiler", + "German Shepherd Dog", + "Dobermann", + "Miniature Pinscher", + "Greater Swiss Mountain Dog", + "Bernese Mountain Dog", + "Appenzeller Sennenhund", + "Entlebucher Sennenhund", + "Boxer", + "Bullmastiff", + "Tibetan Mastiff", + "French Bulldog", + "Great Dane", + "St. Bernard", + "husky", + "Alaskan Malamute", + "Siberian Husky", + "Dalmatian", + "Affenpinscher", + "Basenji", + "pug", + "Leonberger", + "Newfoundland dog", + "Great Pyrenees dog", + "Samoyed", + "Pomeranian", + "Chow Chow", + "Keeshond", + "brussels griffon", + "Pembroke Welsh Corgi", + "Cardigan Welsh Corgi", + "Toy Poodle", + "Miniature Poodle", + "Standard Poodle", + "Mexican hairless dog (xoloitzcuintli)", + "grey wolf", + "Alaskan tundra wolf", + "red wolf or maned wolf", + "coyote", + "dingo", + "dhole", + "African wild dog", + "hyena", + "red fox", + "kit fox", + "Arctic fox", + "grey fox", + "tabby cat", + "tiger cat", + "Persian cat", + "Siamese cat", + "Egyptian Mau", + "cougar", + "lynx", + "leopard", + "snow leopard", + "jaguar", + "lion", + "tiger", + "cheetah", + "brown bear", + "American black bear", + "polar bear", + "sloth bear", + "mongoose", + "meerkat", + "tiger beetle", + "ladybug", + "ground beetle", + "longhorn beetle", + "leaf beetle", + "dung beetle", + "rhinoceros beetle", + "weevil", + "fly", + "bee", + "ant", + "grasshopper", + "cricket insect", + "stick insect", + "cockroach", + "praying mantis", + "cicada", + "leafhopper", + "lacewing", + "dragonfly", + "damselfly", + "red admiral butterfly", + "ringlet butterfly", + "monarch butterfly", + 
"small white butterfly", + "sulphur butterfly", + "gossamer-winged butterfly", + "starfish", + "sea urchin", + "sea cucumber", + "cottontail rabbit", + "hare", + "Angora rabbit", + "hamster", + "porcupine", + "fox squirrel", + "marmot", + "beaver", + "guinea pig", + "common sorrel horse", + "zebra", + "pig", + "wild boar", + "warthog", + "hippopotamus", + "ox", + "water buffalo", + "bison", + "ram (adult male sheep)", + "bighorn sheep", + "Alpine ibex", + "hartebeest", + "impala (antelope)", + "gazelle", + "arabian camel", + "llama", + "weasel", + "mink", + "European polecat", + "black-footed ferret", + "otter", + "skunk", + "badger", + "armadillo", + "three-toed sloth", + "orangutan", + "gorilla", + "chimpanzee", + "gibbon", + "siamang", + "guenon", + "patas monkey", + "baboon", + "macaque", + "langur", + "black-and-white colobus", + "proboscis monkey", + "marmoset", + "white-headed capuchin", + "howler monkey", + "titi monkey", + "Geoffroy's spider monkey", + "common squirrel monkey", + "ring-tailed lemur", + "indri", + "Asian elephant", + "African bush elephant", + "red panda", + "giant panda", + "snoek fish", + "eel", + "silver salmon", + "rock beauty fish", + "clownfish", + "sturgeon", + "gar fish", + "lionfish", + "pufferfish", + "abacus", + "abaya", + "academic gown", + "accordion", + "acoustic guitar", + "aircraft carrier", + "airliner", + "airship", + "altar", + "ambulance", + "amphibious vehicle", + "analog clock", + "apiary", + "apron", + "trash can", + "assault rifle", + "backpack", + "bakery", + "balance beam", + "balloon", + "ballpoint pen", + "Band-Aid", + "banjo", + "baluster / handrail", + "barbell", + "barber chair", + "barbershop", + "barn", + "barometer", + "barrel", + "wheelbarrow", + "baseball", + "basketball", + "bassinet", + "bassoon", + "swimming cap", + "bath towel", + "bathtub", + "station wagon", + "lighthouse", + "beaker", + "military hat (bearskin or shako)", + "beer bottle", + "beer glass", + "bell tower", + "baby bib", + "tandem 
bicycle", + "bikini", + "ring binder", + "binoculars", + "birdhouse", + "boathouse", + "bobsleigh", + "bolo tie", + "poke bonnet", + "bookcase", + "bookstore", + "bottle cap", + "hunting bow", + "bow tie", + "brass memorial plaque", + "bra", + "breakwater", + "breastplate", + "broom", + "bucket", + "buckle", + "bulletproof vest", + "high-speed train", + "butcher shop", + "taxicab", + "cauldron", + "candle", + "cannon", + "canoe", + "can opener", + "cardigan", + "car mirror", + "carousel", + "tool kit", + "cardboard box / carton", + "car wheel", + "automated teller machine", + "cassette", + "cassette player", + "castle", + "catamaran", + "CD player", + "cello", + "mobile phone", + "chain", + "chain-link fence", + "chain mail", + "chainsaw", + "storage chest", + "chiffonier", + "bell or wind chime", + "china cabinet", + "Christmas stocking", + "church", + "movie theater", + "cleaver", + "cliff dwelling", + "cloak", + "clogs", + "cocktail shaker", + "coffee mug", + "coffeemaker", + "spiral or coil", + "combination lock", + "computer keyboard", + "candy store", + "container ship", + "convertible", + "corkscrew", + "cornet", + "cowboy boot", + "cowboy hat", + "cradle", + "construction crane", + "crash helmet", + "crate", + "infant bed", + "Crock Pot", + "croquet ball", + "crutch", + "cuirass", + "dam", + "desk", + "desktop computer", + "rotary dial telephone", + "diaper", + "digital clock", + "digital watch", + "dining table", + "dishcloth", + "dishwasher", + "disc brake", + "dock", + "dog sled", + "dome", + "doormat", + "drilling rig", + "drum", + "drumstick", + "dumbbell", + "Dutch oven", + "electric fan", + "electric guitar", + "electric locomotive", + "entertainment center", + "envelope", + "espresso machine", + "face powder", + "feather boa", + "filing cabinet", + "fireboat", + "fire truck", + "fire screen", + "flagpole", + "flute", + "folding chair", + "football helmet", + "forklift", + "fountain", + "fountain pen", + "four-poster bed", + "freight car", + "French 
horn", + "frying pan", + "fur coat", + "garbage truck", + "gas mask or respirator", + "gas pump", + "goblet", + "go-kart", + "golf ball", + "golf cart", + "gondola", + "gong", + "gown", + "grand piano", + "greenhouse", + "radiator grille", + "grocery store", + "guillotine", + "hair clip", + "hair spray", + "half-track", + "hammer", + "hamper", + "hair dryer", + "hand-held computer", + "handkerchief", + "hard disk drive", + "harmonica", + "harp", + "combine harvester", + "hatchet", + "holster", + "home theater", + "honeycomb", + "hook", + "hoop skirt", + "gymnastic horizontal bar", + "horse-drawn vehicle", + "hourglass", + "iPod", + "clothes iron", + "carved pumpkin", + "jeans", + "jeep", + "T-shirt", + "jigsaw puzzle", + "rickshaw", + "joystick", + "kimono", + "knee pad", + "knot", + "lab coat", + "ladle", + "lampshade", + "laptop computer", + "lawn mower", + "lens cap", + "letter opener", + "library", + "lifeboat", + "lighter", + "limousine", + "ocean liner", + "lipstick", + "slip-on shoe", + "lotion", + "music speaker", + "loupe magnifying glass", + "sawmill", + "magnetic compass", + "messenger bag", + "mailbox", + "tights", + "one-piece bathing suit", + "manhole cover", + "maraca", + "marimba", + "mask", + "matchstick", + "maypole", + "maze", + "measuring cup", + "medicine cabinet", + "megalith", + "microphone", + "microwave oven", + "military uniform", + "milk can", + "minibus", + "miniskirt", + "minivan", + "missile", + "mitten", + "mixing bowl", + "mobile home", + "ford model t", + "modem", + "monastery", + "monitor", + "moped", + "mortar and pestle", + "graduation cap", + "mosque", + "mosquito net", + "vespa", + "mountain bike", + "tent", + "computer mouse", + "mousetrap", + "moving van", + "muzzle", + "metal nail", + "neck brace", + "necklace", + "baby pacifier", + "notebook computer", + "obelisk", + "oboe", + "ocarina", + "odometer", + "oil filter", + "pipe organ", + "oscilloscope", + "overskirt", + "bullock cart", + "oxygen mask", + "product packet / 
packaging", + "paddle", + "paddle wheel", + "padlock", + "paintbrush", + "pajamas", + "palace", + "pan flute", + "paper towel", + "parachute", + "parallel bars", + "park bench", + "parking meter", + "railroad car", + "patio", + "payphone", + "pedestal", + "pencil case", + "pencil sharpener", + "perfume", + "Petri dish", + "photocopier", + "plectrum", + "Pickelhaube", + "picket fence", + "pickup truck", + "pier", + "piggy bank", + "pill bottle", + "pillow", + "ping-pong ball", + "pinwheel", + "pirate ship", + "drink pitcher", + "block plane", + "planetarium", + "plastic bag", + "plate rack", + "farm plow", + "plunger", + "Polaroid camera", + "pole", + "police van", + "poncho", + "pool table", + "soda bottle", + "plant pot", + "potter's wheel", + "power drill", + "prayer rug", + "printer", + "prison", + "missile", + "projector", + "hockey puck", + "punching bag", + "purse", + "quill", + "quilt", + "race car", + "racket", + "radiator", + "radio", + "radio telescope", + "rain barrel", + "recreational vehicle", + "fishing casting reel", + "reflex camera", + "refrigerator", + "remote control", + "restaurant", + "revolver", + "rifle", + "rocking chair", + "rotisserie", + "eraser", + "rugby ball", + "ruler measuring stick", + "sneaker", + "safe", + "safety pin", + "salt shaker", + "sandal", + "sarong", + "saxophone", + "scabbard", + "weighing scale", + "school bus", + "schooner", + "scoreboard", + "CRT monitor", + "screw", + "screwdriver", + "seat belt", + "sewing machine", + "shield", + "shoe store", + "shoji screen / room divider", + "shopping basket", + "shopping cart", + "shovel", + "shower cap", + "shower curtain", + "ski", + "balaclava ski mask", + "sleeping bag", + "slide rule", + "sliding door", + "slot machine", + "snorkel", + "snowmobile", + "snowplow", + "soap dispenser", + "soccer ball", + "sock", + "solar thermal collector", + "sombrero", + "soup bowl", + "keyboard space bar", + "space heater", + "space shuttle", + "spatula", + "motorboat", + "spider web", + 
"spindle", + "sports car", + "spotlight", + "stage", + "steam locomotive", + "through arch bridge", + "steel drum", + "stethoscope", + "scarf", + "stone wall", + "stopwatch", + "stove", + "strainer", + "tram", + "stretcher", + "couch", + "stupa", + "submarine", + "suit", + "sundial", + "sunglasses", + "sunglasses", + "sunscreen", + "suspension bridge", + "mop", + "sweatshirt", + "swim trunks / shorts", + "swing", + "electrical switch", + "syringe", + "table lamp", + "tank", + "tape player", + "teapot", + "teddy bear", + "television", + "tennis ball", + "thatched roof", + "front curtain", + "thimble", + "threshing machine", + "throne", + "tile roof", + "toaster", + "tobacco shop", + "toilet seat", + "torch", + "totem pole", + "tow truck", + "toy store", + "tractor", + "semi-trailer truck", + "tray", + "trench coat", + "tricycle", + "trimaran", + "tripod", + "triumphal arch", + "trolleybus", + "trombone", + "hot tub", + "turnstile", + "typewriter keyboard", + "umbrella", + "unicycle", + "upright piano", + "vacuum cleaner", + "vase", + "vaulted or arched ceiling", + "velvet fabric", + "vending machine", + "vestment", + "viaduct", + "violin", + "volleyball", + "waffle iron", + "wall clock", + "wallet", + "wardrobe", + "military aircraft", + "sink", + "washing machine", + "water bottle", + "water jug", + "water tower", + "whiskey jug", + "whistle", + "hair wig", + "window screen", + "window shade", + "Windsor tie", + "wine bottle", + "airplane wing", + "wok", + "wooden spoon", + "wool", + "split-rail fence", + "shipwreck", + "sailboat", + "yurt", + "website", + "comic book", + "crossword", + "traffic or street sign", + "traffic light", + "dust jacket", + "menu", + "plate", + "guacamole", + "consomme", + "hot pot", + "trifle", + "ice cream", + "popsicle", + "baguette", + "bagel", + "pretzel", + "cheeseburger", + "hot dog", + "mashed potatoes", + "cabbage", + "broccoli", + "cauliflower", + "zucchini", + "spaghetti squash", + "acorn squash", + "butternut squash", + 
"cucumber", + "artichoke", + "bell pepper", + "cardoon", + "mushroom", + "Granny Smith apple", + "strawberry", + "orange", + "lemon", + "fig", + "pineapple", + "banana", + "jackfruit", + "cherimoya (custard apple)", + "pomegranate", + "hay", + "carbonara", + "chocolate syrup", + "dough", + "meatloaf", + "pizza", + "pot pie", + "burrito", + "red wine", + "espresso", + "tea cup", + "eggnog", + "mountain", + "bubble", + "cliff", + "coral reef", + "geyser", + "lakeshore", + "promontory", + "sandbar", + "beach", + "valley", + "volcano", + "baseball player", + "bridegroom", + "scuba diver", + "rapeseed", + "daisy", + "yellow lady's slipper", + "corn", + "acorn", + "rose hip", + "horse chestnut seed", + "coral fungus", + "agaric", + "gyromitra", + "stinkhorn mushroom", + "earth star fungus", + "hen of the woods mushroom", + "bolete", + "corn cob", + "toilet paper", +] + + +openai_imagenet_template = [ + lambda c: f"a bad photo of a {c}.", + lambda c: f"a photo of many {c}.", + lambda c: f"a sculpture of a {c}.", + lambda c: f"a photo of the hard to see {c}.", + lambda c: f"a low resolution photo of the {c}.", + lambda c: f"a rendering of a {c}.", + lambda c: f"graffiti of a {c}.", + lambda c: f"a bad photo of the {c}.", + lambda c: f"a cropped photo of the {c}.", + lambda c: f"a tattoo of a {c}.", + lambda c: f"the embroidered {c}.", + lambda c: f"a photo of a hard to see {c}.", + lambda c: f"a bright photo of a {c}.", + lambda c: f"a photo of a clean {c}.", + lambda c: f"a photo of a dirty {c}.", + lambda c: f"a dark photo of the {c}.", + lambda c: f"a drawing of a {c}.", + lambda c: f"a photo of my {c}.", + lambda c: f"the plastic {c}.", + lambda c: f"a photo of the cool {c}.", + lambda c: f"a close-up photo of a {c}.", + lambda c: f"a black and white photo of the {c}.", + lambda c: f"a painting of the {c}.", + lambda c: f"a painting of a {c}.", + lambda c: f"a pixelated photo of the {c}.", + lambda c: f"a sculpture of the {c}.", + lambda c: f"a bright photo of the 
{c}.", + lambda c: f"a cropped photo of a {c}.", + lambda c: f"a plastic {c}.", + lambda c: f"a photo of the dirty {c}.", + lambda c: f"a jpeg corrupted photo of a {c}.", + lambda c: f"a blurry photo of the {c}.", + lambda c: f"a photo of the {c}.", + lambda c: f"a good photo of the {c}.", + lambda c: f"a rendering of the {c}.", + lambda c: f"a {c} in a video game.", + lambda c: f"a photo of one {c}.", + lambda c: f"a doodle of a {c}.", + lambda c: f"a close-up photo of the {c}.", + lambda c: f"a photo of a {c}.", + lambda c: f"the origami {c}.", + lambda c: f"the {c} in a video game.", + lambda c: f"a sketch of a {c}.", + lambda c: f"a doodle of the {c}.", + lambda c: f"a origami {c}.", + lambda c: f"a low resolution photo of a {c}.", + lambda c: f"the toy {c}.", + lambda c: f"a rendition of the {c}.", + lambda c: f"a photo of the clean {c}.", + lambda c: f"a photo of a large {c}.", + lambda c: f"a rendition of a {c}.", + lambda c: f"a photo of a nice {c}.", + lambda c: f"a photo of a weird {c}.", + lambda c: f"a blurry photo of a {c}.", + lambda c: f"a cartoon {c}.", + lambda c: f"art of a {c}.", + lambda c: f"a sketch of the {c}.", + lambda c: f"a embroidered {c}.", + lambda c: f"a pixelated photo of a {c}.", + lambda c: f"itap of the {c}.", + lambda c: f"a jpeg corrupted photo of the {c}.", + lambda c: f"a good photo of a {c}.", + lambda c: f"a plushie {c}.", + lambda c: f"a photo of the nice {c}.", + lambda c: f"a photo of the small {c}.", + lambda c: f"a photo of the weird {c}.", + lambda c: f"the cartoon {c}.", + lambda c: f"art of the {c}.", + lambda c: f"a drawing of the {c}.", + lambda c: f"a photo of the large {c}.", + lambda c: f"a black and white photo of a {c}.", + lambda c: f"the plushie {c}.", + lambda c: f"a dark photo of a {c}.", + lambda c: f"itap of a {c}.", + lambda c: f"graffiti of the {c}.", + lambda c: f"a toy {c}.", + lambda c: f"itap of my {c}.", + lambda c: f"a photo of a cool {c}.", + lambda c: f"a photo of a small {c}.", + lambda c: f"a 
def tokenizer(text):
    """Encode ``text`` with the module-level RoBERTa tokenizer.

    The input is padded and truncated to 77 tokens and returned as PyTorch
    tensors. ``squeeze(0)`` is then applied to every tensor in the result.

    Args:
        text: Whatever the underlying tokenizer accepts (the demo passes a
            list of strings).

    Returns:
        A plain dict mapping each key of the tokenizer output to its
        squeezed tensor.
    """
    encoded = tokenize(
        text,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )
    squeezed = {}
    for key, tensor in encoded.items():
        squeezed[key] = tensor.squeeze(0)
    return squeezed
def infer_audio():
    """Embed the demo waveform with the pretrained model and print its size.

    The function loads the checkpoint at ``PRETRAINED_PATH``, reads
    ``WAVE_48k_PATH`` resampled to 48 kHz, builds the audio feature dict and
    prints the size of the resulting audio embedding. It returns nothing.

    It runs straight through without stopping in a debugger, so it can be
    used in non-interactive jobs and does not require ``ipdb``.
    """
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    precision = "fp32"
    amodel = "HTSAT-tiny"  # or 'PANN-14'
    tmodel = "roberta"  # the best text encoder in our training
    enable_fusion = False  # False if you do not want to use the fusion model
    fusion_type = "aff_2d"
    pretrained = PRETRAINED_PATH

    model, model_cfg = create_model(
        amodel,
        tmodel,
        pretrained,
        precision=precision,
        device=device,
        enable_fusion=enable_fusion,
        fusion_type=fusion_type,
    )

    # load the waveform of the shape (T,), should resample to 48000
    audio_waveform, _ = librosa.load(WAVE_48k_PATH, sr=48000)
    # quantize: round-trip through int16
    audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
    audio_waveform = torch.from_numpy(audio_waveform).float()
    audio_dict = {}

    # the 'fusion' truncate mode can be changed to 'rand_trunc' if run in unfusion mode
    # NOTE(review): 480000 is presumably the maximum length in samples
    # (10 s at 48 kHz) -- confirm against get_audio_features.
    audio_dict = get_audio_features(
        audio_dict,
        audio_waveform,
        480000,
        data_truncating="fusion",
        data_filling="repeatpad",
        audio_cfg=model_cfg["audio_cfg"],
    )
    # can send a list to the model, to process many audio tracks in one time (i.e. batch size)
    audio_embed = model.get_audio_embedding([audio_dict])
    print(audio_embed.size())
def update_top_k_performance(
    new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True
):
    """Record the top-k performance of the current epoch.

    current_top_k_metrics is a dictionary of the form:
    {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...}

    The new metric is merged into the table. If it displaces an entry, the
    existing ``epoch_top_*.pt`` files are shifted with ``maintain_ckpts`` and
    ``ckpt`` is saved at the rank the new metric achieved.

    Args:
        new_metrics_inputs: A number, or a list/tuple/dict of numbers that
            is reduced to its mean. NumPy scalars are accepted.
        current_top_k_ckpt_metrics: Table of the best metrics so far. It is
            updated in place; its sorted keys define the rank order.
        args: Namespace providing ``checkpoint_path``.
        ckpt: Object passed to ``torch.save`` when the metric ranks.
        bignumbetter: True if larger metric values are better.

    Returns:
        ``(current_top_k_ckpt_metrics, scalar_metric)``.

    Raises:
        TypeError: If the metric cannot be reduced to a real number.
            Previously such input made the function return ``None``.
    """
    # Reduce containers to a single scalar by averaging.
    if isinstance(new_metrics_inputs, dict):
        new_metrics_inputs = np.mean(list(new_metrics_inputs.values()))
    elif isinstance(new_metrics_inputs, (list, tuple)):
        new_metrics_inputs = np.mean(new_metrics_inputs)

    # np.float64 already is a Python float; other NumPy scalars (float32,
    # int64, ...) are not, so unwrap them instead of falling through.
    if isinstance(new_metrics_inputs, np.generic) and not isinstance(
        new_metrics_inputs, (float, int)
    ):
        new_metrics_inputs = new_metrics_inputs.item()
    if not isinstance(new_metrics_inputs, (float, int)):
        raise TypeError(
            "new_metrics_inputs must be a number or a list/tuple/dict of "
            f"numbers, got {type(new_metrics_inputs).__name__}"
        )

    ranked_keys = sorted(current_top_k_ckpt_metrics)
    previous = sorted(current_top_k_ckpt_metrics.values(), reverse=bignumbetter)
    # Insert the newcomer, re-rank, and drop whatever ends up last.
    updated = sorted(previous + [new_metrics_inputs], reverse=bignumbetter)[:-1]
    if updated == previous:
        return current_top_k_ckpt_metrics, new_metrics_inputs

    # Rewrite the table. The first slot that changes is the rank taken by
    # the new checkpoint. Work with positions, not with the table's keys, so
    # 1-based (or any other sortable) keys work as well as 0-based ones.
    new_rank = None
    for position, key in enumerate(ranked_keys):
        if current_top_k_ckpt_metrics[key] != updated[position]:
            current_top_k_ckpt_metrics[key] = updated[position]
            if new_rank is None:
                new_rank = position

    if new_rank is not None:
        # Shift the lower-ranked checkpoint files down before saving.
        maintain_ckpts(args, new_rank, len(ranked_keys))
        torch.save(
            ckpt,
            os.path.join(args.checkpoint_path, f"epoch_top_{new_rank}.pt"),
        )
    return current_top_k_ckpt_metrics, new_metrics_inputs
n.startswith("clap_model.logit_scale_t") + ) + + +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def config_lp_optimizer(model, data, args): + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + in_clap = lambda n, p: n.startswith("clap_model") + + named_parameters = list(model.named_parameters()) + + optimizer = {} + scheduler = {} + + # freeze text encoder + text_freeze_parameters = [ + p + for n, p in named_parameters + if n.startswith("clap_model.transformer") + or n in ["clap_model.positional_embedding", "clap_model.text_projection"] + or n.startswith("clap_model.token_embedding") + or n.startswith("clap_model.ln_final") + ] + + if args.freeze_text: + logging.info("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + if not args.lp_freeze: + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + # (yusong): we do not split the learning rate anymore + # p for n, p in named_parameters if in_clap(n,p) and exclude(n, p) and p.requires_grad + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + # rest_params = [p for n, p in named_parameters if in_clap(n,p) and include(n, p) and p.requires_grad] + rest_params = [ + p for n, p in named_parameters if include(n, p) and p.requires_grad + ] + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and 
p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) + and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) + and (not is_pretrained_params(n)) + ] + + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer["text"] = pretrained_params_optimizer + optimizer["audio"] = new_params_optimizer + scheduler["text"] = pretrained_params_scheduler + scheduler["audio"] = new_params_scheduler + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state( + 
pretrained_params_optimizer, root_rank=0 + ) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + + optimizer["clap"] = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + scheduler["clap"] = cosine_lr( + optimizer["clap"], args.lr, args.warmup, total_steps + ) + + if args.horovod: + optimizer["clap"] = hvd.DistributedOptimizer( + optimizer["clap"], named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer["clap"], root_rank=0) + + # linear probe optimizer + else: + lp_params = [ + p for n, p in named_parameters if (not in_clap(n, p)) and p.requires_grad + ] + lp_optim = get_optimizer( + lp_params, + lr=args.lp_lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=0.9, + optimizer_name=args.optimizer, + ) + optimizer["lp"] = lp_optim + + return optimizer, scheduler, text_freeze_parameters + + +def main(): + args = parse_args() + + time.sleep(args.sleep) + + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? 
+ args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + args.class_index_dict = load_class_label(args.class_label_path) + + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"linear_probe" f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + + # avoid log dir in same name: + postfix = 0 + while os.path.exists(args.log_path): + postfix += 1 + log_base_path_new = log_base_path + "-" + str(postfix) + os.makedirs(log_base_path_new, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path_new, log_filename) + # print( + # "Error. Experiment already exists. Use --name {} to specify a new experiment." 
+ # ) + # return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. 
Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + # Create CLAP model + clap_model, clap_model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=False, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type, + ) + + args.lp_out_ch = len(list(args.class_index_dict.keys())) + # Linear Probe + logging.info(f"linear probe using mlp: {args.lp_mlp}") + logging.info(f"linear probe using freeze: {args.lp_freeze}") + logging.info(f"linear probe act layer: {args.lp_act}") + logging.info(f"linear probe out ch: {args.lp_out_ch}") + logging.info(f"linear probe learning rate (if applicable): {args.lp_lr}") + logging.info(f"linear probe loss func: {args.lp_loss}") + logging.info(f"linear probe lp_metrics: {args.lp_metrics}") + + model = LinearProbe( + clap_model, + mlp=args.lp_mlp, + freeze=args.lp_freeze, + in_ch=512, + out_ch=args.lp_out_ch, + act=args.lp_act, + ) # in_ch is fixed (i.e., 512) + model = model.to(device) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Linear Probe CLAP Model:") + logging.info(f"{str(clap_model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + 
ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, clap_model_cfg) + assert len(data), "At least one train or eval dataset must be specified." + if args.trace: + assert "train" not in data, "Cannot train with traced model" + + optimizer, scheduler, text_freeze_parameters = config_lp_optimizer( + model, data, args + ) + + scaler = GradScaler() if args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine 
if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." + logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + 
top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. + if args.save_logs: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 
+ ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/audioldm/clap/training/lp_train.py b/audioldm/clap/training/lp_train.py new file mode 100644 index 0000000000000000000000000000000000000000..24a19bacd0a4b789415cfccbce1f8bc99bc493ed --- /dev/null +++ b/audioldm/clap/training/lp_train.py @@ -0,0 +1,301 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import LPLoss, LPMetrics, lp_gather_features +from open_clip.utils import do_mixup, get_mix_lambda +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, + data, + epoch, + optimizer, + scaler, + scheduler, + args, + tb_writer=None, + extra_suffix="", +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = LPLoss(args.lp_loss) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + 
sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + step = num_batches_per_epoch * epoch + i + + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch["class_label"] + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + if args.mixup: + # https://github.com/RetroCirce/HTS-Audio-Transformer/blob/main/utils.py#L146 + mix_lambda = torch.from_numpy( + get_mix_lambda(0.5, len(audio["waveform"])) + ).to(device) + class_label = do_mixup(class_label, mix_lambda) + else: + mix_lambda = None + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + pred = model(audio, mix_lambda=mix_lambda, device=device) + total_loss = loss(pred, class_label) + + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. 
+ with torch.no_grad(): + unwrap_model(model).clap_model.logit_scale_a.clamp_(0, math.log(100)) + unwrap_model(model).clap_model.logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audio, dict): + batch_size = len(audio["waveform"]) + else: + batch_size = len(audio) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + if isinstance(optimizer, dict): + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + ) + + # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = f"train{extra_suffix}/{name}" + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None, extra_suffix=""): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print("Evaluating...") + metric_names = args.lp_metrics.split(",") + eval_tool = LPMetrics(metric_names=metric_names) + + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + if args.parallel_eval: + dataloader, sampler = data["val"].dataloader, data["val"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + samples_per_val = dataloader.num_samples + else: + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + eval_info = {"pred": [], "target": []} + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audio = batch # contains mel_spec, wavform, and longer list + class_label = batch["class_label"] + + # audio = audio.to(device=device, non_blocking=True) + class_label = class_label.to(device=device, non_blocking=True) + + with autocast(): + pred = model(audio, device=device) + if args.parallel_eval: + pred, class_label = 
lp_gather_features( + pred, class_label, args.world_size, args.horovod + ) + eval_info["pred"].append(pred) + eval_info["target"].append(class_label) + + num_samples += class_label.shape[0] + + if (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + + if is_master(args): + eval_info["pred"] = torch.cat(eval_info["pred"], 0).cpu() + eval_info["target"] = torch.cat(eval_info["target"], 0).cpu() + metric_dict = eval_tool.evaluate_mertics( + eval_info["pred"], eval_info["target"] + ) + metrics.update(metric_dict) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + ["\t".join([f"{m}: {round(metrics[m], 4):.4f}"]) for m in metrics] + ) + ) + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val{extra_suffix}/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + assert wandb is not None, "Please install wandb." 
+ for name, val in metrics.items(): + wandb.log({f"val{extra_suffix}/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics diff --git a/audioldm/clap/training/main.py b/audioldm/clap/training/main.py new file mode 100644 index 0000000000000000000000000000000000000000..3b563a5d001be7adfbe779dee7ad8ac49aadc50d --- /dev/null +++ b/audioldm/clap/training/main.py @@ -0,0 +1,596 @@ +from inspect import getargs +import logging +import os +import random +from datetime import datetime +import bisect +import copy +import numpy as np +import torch +import torch.backends.cudnn as cudnn +from torch import optim +from torch.cuda.amp import GradScaler +import faulthandler +import pathlib + +try: + import wandb +except ImportError: + wandb = None + +try: + import torch.utils.tensorboard as tensorboard +except ImportError: + tensorboard = None + +try: + import horovod.torch as hvd +except ImportError: + hvd = None + +from open_clip import create_model_and_transforms, trace_model, create_model +from training.data import get_data +from training.distributed import is_master, init_distributed_device, world_info_from_env +from training.logger import setup_logging +from training.params import parse_args +from training.scheduler import cosine_lr +from training.train import train_one_epoch, evaluate +from open_clip.utils import dataset_split, get_optimizer + + +def maintain_ckpts(args, startidx, all_idx_len): + for i in reversed(range(startidx, all_idx_len)): + if os.path.exists(os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt")): + os.rename( + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + os.path.join(args.checkpoint_path, f"epoch_top_{i+1}.pt"), + ) + if os.path.exists( + os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt") + ): + os.remove(os.path.join(args.checkpoint_path, f"epoch_top_{all_idx_len}.pt")) + return + + +def update_top_k_performance( + new_metrics_inputs, current_top_k_ckpt_metrics, args, ckpt, bignumbetter=True +): + 
""" + Record the top-k performance of the current epoch. + current_top_k_metrics is a dictionary of the form: {1: top_1_ckpt_measure, 2: top_2_ckpt_measure, ...} + """ + if isinstance(new_metrics_inputs, (list, tuple)): + new_metrics_inputs = np.mean(new_metrics_inputs) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, dict): + new_metrics_inputs = np.mean(list(new_metrics_inputs.values())) + return update_top_k_performance( + new_metrics_inputs, + current_top_k_ckpt_metrics, + args=args, + ckpt=ckpt, + bignumbetter=bignumbetter, + ) + elif isinstance(new_metrics_inputs, (float, int)): + update_flag = {k: False for k in current_top_k_ckpt_metrics.keys()} + sorted_keys = sorted(current_top_k_ckpt_metrics.keys()) + sorted_values = sorted( + current_top_k_ckpt_metrics.values(), reverse=bignumbetter + ) + sorted_values_ = copy.deepcopy(sorted_values) + sorted_values.append(new_metrics_inputs) + sorted_values = sorted(sorted_values, reverse=bignumbetter) + sorted_values = sorted_values[:-1] + + if sorted_values == sorted_values_: + return current_top_k_ckpt_metrics, new_metrics_inputs + else: + for i in range(len(sorted_keys)): + if current_top_k_ckpt_metrics[sorted_keys[i]] != sorted_values[i]: + current_top_k_ckpt_metrics[sorted_keys[i]] = sorted_values[i] + update_flag[sorted_keys[i]] = True + for i in range(len(update_flag)): + if update_flag[i]: + maintain_ckpts(args, i, len(sorted_keys)) + torch.save( + ckpt, + os.path.join(args.checkpoint_path, f"epoch_top_{i}.pt"), + ) + break + return current_top_k_ckpt_metrics, new_metrics_inputs + + +# def updateifNone(a, b): +# a = b if None else a +# return a + + +def is_pretrained_params(n): + return ( + n.startswith("transformer") + or n in ["positional_embedding", "text_projection"] + or n.startswith("token_embedding") + or n.startswith("ln_final") + or n.startswith("logit_scale_t") + ) + 
+ +def random_seed(seed=42, rank=0): + torch.manual_seed(seed + rank) + np.random.seed(seed + rank) + random.seed(seed + rank) + + +def main(): + args = parse_args() + # sanitize model name for filesystem / uri use, easier if we don't use / in name as a rule? + args.amodel = args.amodel.replace("/", "-") + # download sizes.json file + + # (yusong): the below two lines are for debug + # print("setting up faulthandler") + # faulthandler.register(10) + + random.seed(args.seed) + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + np.random.seed(args.seed) + if args.tmodel == "bert" or args.tmodel == "roberta" or args.tmodel == "bart": + assert ( + args.pretrained == "" or args.pretrained is None + ), "bert/roberta/bart text encoder does not support pretrained models." + + # get the name of the experiments + if args.name is None: + args.name = "-".join( + [ + datetime.now().strftime("%Y_%m_%d-%H_%M_%S"), + f"model_{args.amodel}", + f"lr_{args.lr}", + f"b_{args.batch_size}", + f"j_{args.workers}", + f"p_{args.precision}", + ] + ) + + # discover initial world args early so we can log properly + args.distributed = False + args.local_rank, args.rank, args.world_size = world_info_from_env() + + if args.remotedata and is_master(args): + for dataset_name in args.datasetnames: + for split in dataset_split[dataset_name]: + if not os.path.exists(f"./json_files/{dataset_name}/{split}"): + os.makedirs(f"./json_files/{dataset_name}/{split}") + os.system( + f"aws s3 cp s3://s-laion-audio/webdataset_tar/{dataset_name}/{split}/sizes.json ./json_files/{dataset_name}/{split}/sizes.json" + ) + + args.log_path = None + if is_master(args, local=args.log_local): + log_base_path = os.path.join(args.logs, args.name) + os.makedirs(log_base_path, exist_ok=True) + log_filename = f"out-{args.rank}" if args.log_local else "out.log" + args.log_path = os.path.join(log_base_path, log_filename) + if os.path.exists(args.log_path): + print( + 
"Error. Experiment already exists. Use --name {} to specify a new experiment." + ) + return -1 + + # Set logger + args.log_level = logging.DEBUG if args.debug else logging.INFO + setup_logging(args.log_path, args.log_level) + + # fully initialize distributed device environment + device = init_distributed_device(args) + + args.wandb = "wandb" in args.report_to or "all" in args.report_to + args.tensorboard = "tensorboard" in args.report_to or "all" in args.report_to + if is_master(args): + args.tensorboard_path = ( + os.path.join(args.logs, args.name, "tensorboard") + if args.tensorboard + else "" + ) + args.checkpoint_path = os.path.join(args.logs, args.name, "checkpoints") + for dirname in [args.tensorboard_path, args.checkpoint_path]: + if dirname: + os.makedirs(dirname, exist_ok=True) + else: + args.tensorboard_path = "" + args.checkpoint_path = "" + + if args.copy_codebase: + copy_codebase(args) + + assert args.precision in ["amp", "fp16", "fp32"] + if args.precision == "fp16": + logging.warning( + "It is recommended to use AMP mixed-precision instead of FP16. " + "FP16 support needs further verification and tuning, especially for train." + ) + + if args.horovod: + logging.info( + f"Running in horovod mode with multiple processes / nodes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + elif args.distributed: + logging.info( + f"Running in distributed mode with multiple processes. Device: {args.device}." + f"Process (global: {args.rank}, local {args.local_rank}), total {args.world_size}." + ) + else: + logging.info(f"Running with a single process. 
Device {args.device}.") + + logging.info(f"openai cache dir: {os.path.expanduser(args.openai_model_cache_dir)}") + + model, model_cfg = create_model( + args.amodel, + args.tmodel, + args.pretrained, + precision=args.precision, + device=device, + jit=args.torchscript, + force_quick_gelu=args.force_quick_gelu, + openai_model_cache_dir=os.path.expanduser(args.openai_model_cache_dir), + skip_params=True, + pretrained_audio=args.pretrained_audio, + pretrained_text=args.pretrained_text, + enable_fusion=args.enable_fusion, + fusion_type=args.fusion_type, + ) + + if args.horovod: + with torch.no_grad(): + for param in model.parameters(): + param.set_(param.contiguous()) + + if args.trace: + model = trace_model(model, batch_size=args.batch_size, device=device) + + if is_master(args): + logging.info("Model:") + logging.info(f"{str(model)}") + logging.info("Params:") + params_file = os.path.join(args.logs, args.name, "params.txt") + with open(params_file, "w") as f: + for name in sorted(vars(args)): + val = getattr(args, name) + logging.info(f" {name}: {val}") + f.write(f"{name}: {val}\n") + + if args.distributed and not args.horovod: + if args.use_bn_sync: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + ddp_args = {} + if args.ddp_static_graph: + # this doesn't exist in older PyTorch, arg only added if enabled + ddp_args["static_graph"] = True + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[device], find_unused_parameters=True, **ddp_args + ) + + data = get_data(args, model_cfg) + assert len(data), "At least one train or eval dataset must be specified." 
+ if args.trace: + assert "train" not in data, "Cannot train with traced model" + + exclude = ( + lambda n, p: p.ndim < 2 + or "bn" in n + or "ln" in n + or "bias" in n + or "logit_scale" in n + ) + include = lambda n, p: not exclude(n, p) + + named_parameters = list(model.named_parameters()) + + # freeze text encoder + text_freeze_parameters = [p for n, p in named_parameters if "text_branch" in n] + + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + + gain_or_bias_params = [ + p for n, p in named_parameters if exclude(n, p) and p.requires_grad + ] + rest_params = [p for n, p in named_parameters if include(n, p) and p.requires_grad] + + # set wd-related params to 0 if use adam optimizer + if args.optimizer == "adam": + args.wd = 0 + args.wd_pretrained = 0 + args.wd_new = 0 + + if args.train_data is None: + optimizer = None + scheduler = None + else: + total_steps = data["train"].dataloader.num_batches * args.epochs + + if args.split_opt: + for x in ["lr", "beta1", "beta2", "eps", "wd"]: + for y in ["_new", "_pretrained"]: + if getattr(args, x + y) is None: + setattr(args, x + y, getattr(args, x)) + + gain_or_bias_pretrained_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + rest_pretrained_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and is_pretrained_params(n) + ] + gain_or_bias_new_params = [ + p + for n, p in named_parameters + if (exclude(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + rest_new_params = [ + p + for n, p in named_parameters + if (include(n, p) and p.requires_grad) and (not is_pretrained_params(n)) + ] + pretrained_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_pretrained_params, "weight_decay": 0.0}, + { + "params": rest_pretrained_params, + "weight_decay": args.wd_pretrained, + }, + ], + lr=args.lr_pretrained, + 
betas=(args.beta1_pretrained, args.beta2_pretrained), + eps=args.eps_pretrained, + momentum=args.momentum_pretrained, + optimizer_name=args.optimizer, + ) + pretrained_params_scheduler = cosine_lr( + pretrained_params_optimizer, + args.lr_pretrained, + args.warmup, + total_steps, + ) + new_params_optimizer = get_optimizer( + [ + {"params": gain_or_bias_new_params, "weight_decay": 0.0}, + {"params": rest_new_params, "weight_decay": args.wd_new}, + ], + lr=args.lr_new, + betas=(args.beta1_new, args.beta2_new), + eps=args.eps_new, + momentum=args.momentum_new, + optimizer_name=args.optimizer, + ) + + new_params_scheduler = cosine_lr( + new_params_optimizer, args.lr_new, args.warmup, total_steps + ) + + optimizer = { + "pretrained": pretrained_params_optimizer, + "new": new_params_optimizer, + } + scheduler = { + "pretrained": pretrained_params_scheduler, + "new": new_params_scheduler, + } + + if args.horovod: + pretrained_params_optimizer = hvd.DistributedOptimizer( + pretrained_params_optimizer, + named_parameters=model.named_parameters(), + ) + new_params_optimizer = hvd.DistributedOptimizer( + new_params_optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(pretrained_params_optimizer, root_rank=0) + hvd.broadcast_optimizer_state(new_params_optimizer, root_rank=0) + else: + optimizer = get_optimizer( + [ + {"params": gain_or_bias_params, "weight_decay": 0.0}, + {"params": rest_params, "weight_decay": args.wd}, + ], + lr=args.lr, + betas=(args.beta1, args.beta2), + eps=args.eps, + momentum=args.momentum, + optimizer_name=args.optimizer, + ) + + scheduler = cosine_lr(optimizer, args.lr, args.warmup, total_steps) + + if args.horovod: + optimizer = hvd.DistributedOptimizer( + optimizer, named_parameters=model.named_parameters() + ) + hvd.broadcast_parameters(model.state_dict(), root_rank=0) + hvd.broadcast_optimizer_state(optimizer, root_rank=0) + + scaler = GradScaler() if 
args.precision == "amp" else None + + # optionally resume from a checkpoint + start_epoch = 0 + if args.resume is not None: + if os.path.isfile(args.resume): + checkpoint = torch.load(args.resume, map_location=device) + if "epoch" in checkpoint: + # resuming a train checkpoint w/ epoch and optimizer state + start_epoch = checkpoint["epoch"] + sd = checkpoint["state_dict"] + if not args.distributed and next(iter(sd.items()))[0].startswith( + "module" + ): + sd = {k[len("module.") :]: v for k, v in sd.items()} + model.load_state_dict(sd) + if args.split_opt: + if optimizer is not None: + for k, o_ in optimizer.items(): + o_.load_state_dict(checkpoint[k + "_" + "optimizer"]) + if optimizer is not None: + optimizer.load_state_dict(checkpoint["optimizer"]) + if scaler is not None and "scaler" in checkpoint: + scaler.load_state_dict(checkpoint["scaler"]) + logging.info( + f"=> resuming checkpoint '{args.resume}' (epoch {start_epoch})" + ) + else: + # loading a bare (model only) checkpoint for fine-tune or evaluation + model.load_state_dict(checkpoint) + logging.info( + f"=> loaded checkpoint '{args.resume}' (epoch {start_epoch})" + ) + if args.freeze_text: + print("Freeze Text!!!!") + for k in text_freeze_parameters: + k.requires_grad = False + else: + logging.info("=> no checkpoint found at '{}'".format(args.resume)) + + cudnn.benchmark = True + cudnn.deterministic = False + + # determine if this worker should save logs and checkpoints. only do so if it is rank == 0 + args.save_logs = args.logs and args.logs.lower() != "none" and is_master(args) + writer = None + if args.save_logs and args.tensorboard: + assert tensorboard is not None, "Please install tensorboard." + writer = tensorboard.SummaryWriter(args.tensorboard_path) + + if args.wandb and is_master(args): + assert wandb is not None, "Please install wandb." 
+ logging.debug("Starting wandb.") + args.train_sz = data["train"].dataloader.num_samples + if args.val_data is not None: + args.val_sz = data["val"].dataloader.num_samples + # you will have to configure this for your project! + wandb.init( + project="clap", + notes=args.wandb_notes, + name=args.wandb_notes, + tags=[], + config=vars(args), + ) + if args.debug: + wandb.watch(model, log="all") + wandb.save(params_file) + logging.debug("Finished loading wandb.") + + if "train" not in data: + evaluate(model, data, start_epoch, args, writer) + return + elif start_epoch == 0 and "val" in data and not args.no_eval: + evaluate(model, data, 0, args, writer) + # print(f'rank {args.rank}, Start First Evaluation')# (yusong): for debug + if args.save_top_performance: + current_top_k_ckpt_metrics = { + i: 0 for i in range(args.save_top_performance) + } # initialize the top-k metric for ckpts to 0 + + # print(f'rank {args.rank}, Start Training') # (yusong): for debug + for epoch in range(start_epoch, args.epochs): + # freeze the text param after (include) args.freeze_text_after, this is -1 by default + if epoch == args.freeze_text_after: + print("Text pretrained parameters are freezed since this epoch.") + for k in text_freeze_parameters: + k.requires_grad = False + if is_master(args): + logging.info(f"Start epoch {epoch}") + + train_one_epoch(model, data, epoch, optimizer, scaler, scheduler, args, writer) + completed_epoch = epoch + 1 + + if ( + any(v in data for v in ("val", "imagenet-val", "imagenet-v2")) + and not args.no_eval + ): + metrics = evaluate(model, data, completed_epoch, args, writer) + if args.save_top_performance: + top_k_dataset = args.top_k_checkpoint_select_dataset + top_k_metric = args.top_k_checkpoint_select_metric + filtered_metrics = [ + v + for k, v in metrics.items() + if top_k_metric in k and top_k_dataset in k + ] # check all R@10 metrics (all dataset) and use it to update the ckpt + # Saving checkpoints. 
+ if args.save_logs: + if args.split_opt: + opt_dict = { + k + "_" + "optimizer": v.state_dict() for k, v in optimizer.items() + } + else: + opt_dict = {"optimizer": optimizer.state_dict()} + checkpoint_dict = { + "epoch": completed_epoch, + "name": args.name, + "state_dict": model.state_dict(), + } + checkpoint_dict.update(opt_dict) + if scaler is not None: + checkpoint_dict["scaler"] = scaler.state_dict() + + if completed_epoch == args.epochs or ( + args.save_frequency > 0 and (completed_epoch % args.save_frequency) == 0 + ): + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_{completed_epoch}.pt"), + ) + if args.save_most_recent: + torch.save( + checkpoint_dict, + os.path.join(args.checkpoint_path, f"epoch_latest.pt"), + ) + if args.save_top_performance and not args.no_eval: + update_top_k_performance( + filtered_metrics, + current_top_k_ckpt_metrics, + args, + checkpoint_dict, + bignumbetter=True, + ) + + if args.wandb and is_master(args): + wandb.finish() + + +def copy_codebase(args): + from shutil import copytree, ignore_patterns + + new_code_path = os.path.join(args.logs, args.name, "code") + if os.path.exists(new_code_path): + print( + f"Error. Experiment already exists at {new_code_path}. Use --name to specify a new experiment." 
+ ) + return -1 + print(f"Copying codebase to {new_code_path}") + current_code_path = os.path.realpath(__file__) + for _ in range(3): + current_code_path = os.path.dirname(current_code_path) + copytree( + current_code_path, new_code_path, ignore=ignore_patterns("log", "logs", "wandb") + ) + print("Done copying code.") + return 1 + + +if __name__ == "__main__": + main() diff --git a/audioldm/clap/training/params.py b/audioldm/clap/training/params.py new file mode 100644 index 0000000000000000000000000000000000000000..b1933e3a78ff583733846ea285d56eb0a0b892a5 --- /dev/null +++ b/audioldm/clap/training/params.py @@ -0,0 +1,569 @@ +import argparse +import os + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", + "~/.cache") + + + +def get_default_params(model_name): + # Params from paper (https://arxiv.org/pdf/2103.00020.pdf) + model_name = model_name.lower() + if "vit" in model_name: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.98, "eps": 1.0e-6} + else: + return {"lr": 5.0e-4, "beta1": 0.9, "beta2": 0.999, "eps": 1.0e-8} + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--train-data", + type=str, + default=None, + help="Path to h5 filewith training data", + ) + parser.add_argument( + "--val-data", + type=str, + default=None, + help="Path to h5 file with validation data", + ) + parser.add_argument( + "--freeze-text", + default=False, + action="store_true", + help="if you need to freeze the text encoder, make this True", + ) + parser.add_argument( + "--freeze-text-after", + type=int, + default=-1, + help="if you need to freeze the text encoder after (include) epoch x, set this param to x. 
Set -1 to disable it", + ) + parser.add_argument( + "--train-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in training data", + ) + parser.add_argument( + "--val-ipc", + type=str, + default=None, + help="Path to npy file of the number of instance per class in validation data", + ) + parser.add_argument( + "--train-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Required for webdataset if not available in info file.", + ) + parser.add_argument( + "--val-num-samples", + type=int, + default=None, + help="Number of samples in dataset. Useful for webdataset if not available in info file.", + ) + parser.add_argument( + "--dataset-type", + choices=["webdataset", "csv", "auto", "toy"], + default="auto", + help="Which type of dataset to process.", + ) + parser.add_argument( + "--csv-separator", + type=str, + default="\t", + help="For csv-like datasets, which separator to use.", + ) + parser.add_argument( + "--csv-img-key", + type=str, + default="filepath", + help="For csv-like datasets, the name of the key for the image paths.", + ) + parser.add_argument( + "--csv-caption-key", + type=str, + default="title", + help="For csv-like datasets, the name of the key for the captions.", + ) + parser.add_argument( + "--imagenet-val", + type=str, + default=None, + help="Path to imagenet val set for conducting zero shot evaluation.", + ) + parser.add_argument( + "--imagenet-v2", + type=str, + default=None, + help="Path to imagenet v2 for conducting zero shot evaluation.", + ) + parser.add_argument( + "--datasetnames", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset names to load. Can be some of these: Clotho, audioset, audiocaps, BBCSoundEffects", + ) + parser.add_argument( + "--full-train-dataset", + nargs="+", + default=None, + help="Which dataset will be trained with all the subsets. 
(train+test)", + ) + parser.add_argument( + "--exclude-eval-dataset", + nargs="+", + default=None, + help="Which dataset will be excluded with evaluation", + ) + parser.add_argument( + "--datasetinfos", + nargs="+", + default=None, + help="If loading webdataset, spedify the dataset types to load. Can be some of these: train, test, valid, unbalanced_train, balanced_train, eval", + ) + parser.add_argument( + "--dataset-proportion", + type=float, + default=1.0, + help="How much proportion of dataset we want to train.", + ) + parser.add_argument( + "--remotedata", + default=False, + action="store_true", + help="if the dataset is remote, set this flag", + ) + parser.add_argument( + "--class-label-path", + type=str, + default=None, + help="The path of the class label pickle or csv.", + ) + parser.add_argument( + "--datasetpath", + type=str, + default="/mnt/audio_clip/webdataset_tar", + help="The path to the dataset", + ) + parser.add_argument( + "--logs", + type=str, + default="./logs/", + help="Where to store tensorboard logs. Use None to avoid storing logs.", + ) + parser.add_argument( + "--log-local", + action="store_true", + default=False, + help="log files on local master, otherwise global master only.", + ) + parser.add_argument( + "--name", + type=str, + default=None, + help="Optional identifier for the experiment when storing logs. Otherwise use current time.", + ) + parser.add_argument( + "--workers", type=int, default=1, help="Number of workers per GPU." + ) + parser.add_argument( + "--batch-size", type=int, default=64, help="Batch size per GPU." + ) + parser.add_argument( + "--epochs", type=int, default=32, help="Number of epochs to train for." 
+ ) + parser.add_argument("--lr", type=float, default=None, help="Learning rate.") + parser.add_argument("--beta1", type=float, default=None, help="Adam beta 1.") + parser.add_argument("--beta2", type=float, default=None, help="Adam beta 2.") + parser.add_argument("--eps", type=float, default=None, help="Adam epsilon.") + parser.add_argument("--momentum", type=float, default=None, help="SGD epsilon.") + parser.add_argument("--wd", type=float, default=0.2, help="Weight decay.") + + parser.add_argument( + "--split-opt", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--lr-pretrained", type=float, default=None, help="Learning rate for text." + ) + parser.add_argument( + "--beta1-pretrained", type=float, default=None, help="Adam beta 1 for text." + ) + parser.add_argument( + "--beta2-pretrained", type=float, default=None, help="Adam beta 2 for text." + ) + parser.add_argument( + "--eps-pretrained", type=float, default=None, help="Adam epsilon for text." + ) + parser.add_argument( + "--wd-pretrained", type=float, default=0.2, help="Weight decay for text." + ) + parser.add_argument( + "--momentum-pretrained", type=float, default=0.9, help="Momentum for text." + ) + parser.add_argument( + "--lr-new", type=float, default=None, help="Learning rate for audio." + ) + parser.add_argument( + "--beta1-new", type=float, default=None, help="Adam beta 1 for audio." + ) + parser.add_argument( + "--beta2-new", type=float, default=None, help="Adam beta 2 for audio." + ) + parser.add_argument( + "--eps-new", type=float, default=None, help="Adam epsilon for audio." + ) + parser.add_argument( + "--wd-new", type=float, default=0.2, help="Weight decay for audio." + ) + parser.add_argument( + "--momentum-new", type=float, default=0.9, help="Momentum for audio." + ) + parser.add_argument( + "--warmup", type=int, default=10000, help="Number of steps to warmup for." 
+ ) + parser.add_argument( + "--use-bn-sync", + default=False, + action="store_true", + help="Whether to use batch norm sync.", + ) + parser.add_argument( + "--skip-scheduler", + action="store_true", + default=False, + help="Use this flag to skip the learning rate decay.", + ) + parser.add_argument( + "--save-frequency", type=int, default=1, help="How often to save checkpoints." + ) + parser.add_argument( + "--save-top-performance", + type=int, + default=0, + help="Save the top x performance weights if the value >0", + ) + parser.add_argument( + "--save-most-recent", + action="store_true", + default=False, + help="Always save the most recent model trained to epoch_latest.pt.", + ) + parser.add_argument( + "--zeroshot-frequency", type=int, default=2, help="How often to run zero shot." + ) + parser.add_argument( + "--val-frequency", + type=int, + default=1, + help="How often to run evaluation with val data.", + ) + parser.add_argument( + "--resume", + default=None, + type=str, + help="path to latest checkpoint (default: none)", + ) + parser.add_argument( + "--precision", + choices=["amp", "fp16", "fp32"], + default="amp", + help="Floating point precision.", + ) + parser.add_argument( + "--amodel", + type=str, + default="RN50", + help="Name of the audio backbone to use.", + ) + parser.add_argument( + "--tmodel", + type=str, + default="transformer", + help="Name of the text backbone to use. 
Can be [transformer, bert, roberta, bart]", + ) + parser.add_argument( + "--pretrained-audio", + default="", + type=str, + help="Use a pretrained audio model weights for the audio encoder of CLAP", + ) + parser.add_argument( + "--pretrained-text", + default="", + type=str, + help="Use a pretrained text model weights for the text encoder of CLAP", + ) + parser.add_argument( + "--pretrained", + default="", + type=str, + help="Use a pretrained CLIP model weights with the specified tag or file path.", + ) + parser.add_argument( + "--pretrained-image", + default=False, + action="store_true", + help="Load imagenet pretrained weights for image tower backbone if available.", + ) + parser.add_argument( + "--lock-image", + default=False, + action="store_true", + help="Lock full image tower by disabling gradients.", + ) + parser.add_argument( + "--lock-image-unlocked-groups", + type=int, + default=0, + help="Leave last n image tower layer groups unlocked.", + ) + parser.add_argument( + "--lock-image-freeze-bn-stats", + default=False, + action="store_true", + help="Freeze BatchNorm running stats in image tower for any locked layers.", + ) + parser.add_argument( + "--local-loss", + default=False, + action="store_true", + help="calculate loss w/ local features @ global (instead of realizing full global @ global matrix)", + ) + parser.add_argument( + "--gather-with-grad", + default=False, + action="store_true", + help="enable full distributed gradient for feature gather", + ) + parser.add_argument( + "--force-quick-gelu", + default=False, + action="store_true", + help="Force use of QuickGELU activation for non-OpenAI transformer models.", + ) + parser.add_argument( + "--torchscript", + default=False, + action="store_true", + help="torch.jit.script the model, also uses jit version of OpenAI models if pretrained=='openai'", + ) + parser.add_argument( + "--trace", + default=False, + action="store_true", + help="torch.jit.trace the model for inference / eval only", + ) + # arguments 
for distributed training + parser.add_argument( + "--dist-url", + default="env://", + type=str, + help="url used to set up distributed training", + ) + parser.add_argument( + "--dist-backend", default="nccl", type=str, help="distributed backend" + ) + parser.add_argument( + "--report-to", + default="", + type=str, + help="Options are ['wandb', 'tensorboard', 'wandb,tensorboard']", + ) + parser.add_argument( + "--wandb-notes", default="", type=str, help="Notes if logging with wandb" + ) + parser.add_argument( + "--C", type=float, default=3.16, help="inverse regularizer for logistic reg." + ) + parser.add_argument( + "--debug", + default=False, + action="store_true", + help="If true, more information is logged.", + ) + parser.add_argument( + "--copy-codebase", + default=False, + action="store_true", + help="If true, we copy the entire base on the log diretory, and execute from there.", + ) + parser.add_argument( + "--horovod", + default=False, + action="store_true", + help="Use horovod for distributed training.", + ) + parser.add_argument( + "--ddp-static-graph", + default=False, + action="store_true", + help="Enable static graph optimization for DDP in PyTorch >= 1.11.", + ) + parser.add_argument( + "--no-set-device-rank", + default=False, + action="store_true", + help="Don't set device index from local rank (when CUDA_VISIBLE_DEVICES restricted to one per proc).", + ) + parser.add_argument("--seed", type=int, default=4242, help="Default random seed.") + + parser.add_argument( + "--top-k-checkpoint-select-dataset", + type=str, + default="all", + help="The dataset of selecting top-k checkpoint.", + ) + + # @R10, @R@5, @R1, mAP@10 + parser.add_argument( + "--top-k-checkpoint-select-metric", + type=str, + default="_R@10", + help="The metric for selecting top-k checkpoint.", + ) + parser.add_argument( + "--openai-model-cache-dir", + type=str, + default=f"{CACHE_DIR}/clip", + help="Directory to download OpenAI models.", + ) + parser.add_argument( + "--optimizer", + 
type=str, + default="adamw", + help="can be AdamW or SGD", + ) + parser.add_argument( + "--parallel-eval", + default=False, + action="store_true", + help="Eval in parallel (multi-GPU, multi-node).", + ) + + parser.add_argument( + "--no-eval", + default=False, + action="store_true", + help="Training without evaluation.", + ) + + parser.add_argument( + "--lp-mlp", + default=False, + action="store_true", + help="Linear Probe using MLP layer or not.", + ) + + parser.add_argument( + "--lp-freeze", + default=False, + action="store_true", + help="Linear Probe using Freeze CLAP or not", + ) + + parser.add_argument( + "--lp-act", + default="None", + type=str, + help="Options are ['relu','elu','prelu','softmax','sigmoid']", + ) + + parser.add_argument( + "--lp-loss", type=str, default="bce", help="Loss func of Linear Probe." + ) + + parser.add_argument( + "--lp-metrics", + type=str, + default="map,mauc,acc", + help="Metrics of Linear Probe.", + ) + + parser.add_argument( + "--lp-lr", type=float, default=1e-4, help="learning rate of linear probe" + ) + parser.add_argument( + "--kappa", + type=float, + default=0, + help="the kappa in the weighted contrastive loss, default is to turn off the weighted contrastive loss", + ) + + parser.add_argument( + "--data-filling", + type=str, + default="pad", + help="type of data filling when the audio length is shorter than the max length." + "Can be one of the following: repeat, repeatpad, pad", + ) + parser.add_argument( + "--data-truncating", + type=str, + default="rand_trunc", + help="type of data truncation when the audio length is longer than the max length." 
+ "Can be one of the following: rand_trunc, fusion", + ) + + parser.add_argument( + "--clap-mlploss", + default=False, + action="store_true", + help="Using MLP loss for CLAP model or not", + ) + + parser.add_argument( + "--wandb-id", + type=str, + default=None, + help="the id of wandb experiment to restore.", + ) + + parser.add_argument( + "--sleep", type=float, default=0, help="sleep n seconds before start training" + ) + + # variable length processing + parser.add_argument( + "--enable-fusion", + default=False, + action="store_true", + help="Enable feature funsion for variable-length data", + ) + + parser.add_argument( + "--fusion-type", + type=str, + default="None", + help="Type is among ['channel_map', 'daf_1d','aff_1d','iaff_1d','daf_2d','aff_2d','iaff_2d']", + ) + + parser.add_argument( + "--mixup", + default=False, + action="store_true", + help="Enable mixup in finetuning training.", + ) + parser.add_argument( + "--text-augment-selection", + type=str, + default=None, + help="For selecting levels of augmented text. Type is among ['all', 'augment_only', 'none']", + ) + + args = parser.parse_args() + + # If some params are not passed, we use the default values based on model name. 
+ default_params = get_default_params(args.amodel) + for name, val in default_params.items(): + if getattr(args, name) is None: + setattr(args, name, val) + + return args diff --git a/audioldm/clap/training/scheduler.py b/audioldm/clap/training/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..7151ffbab25a113673b7627027b443b27f22cb0f --- /dev/null +++ b/audioldm/clap/training/scheduler.py @@ -0,0 +1,24 @@ +import numpy as np + + +def assign_learning_rate(optimizer, new_lr): + for param_group in optimizer.param_groups: + param_group["lr"] = new_lr + + +def _warmup_lr(base_lr, warmup_length, step): + return base_lr * (step + 1) / warmup_length + + +def cosine_lr(optimizer, base_lr, warmup_length, steps): + def _lr_adjuster(step): + if step < warmup_length: + lr = _warmup_lr(base_lr, warmup_length, step) + else: + e = step - warmup_length + es = steps - warmup_length + lr = 0.5 * (1 + np.cos(np.pi * e / es)) * base_lr + assign_learning_rate(optimizer, lr) + return lr + + return _lr_adjuster diff --git a/audioldm/clap/training/train.py b/audioldm/clap/training/train.py new file mode 100644 index 0000000000000000000000000000000000000000..f5759c4679d2ee9c0748444adf66b8453cf09728 --- /dev/null +++ b/audioldm/clap/training/train.py @@ -0,0 +1,838 @@ +import json +import logging +import math +import os +import time +from contextlib import suppress + +import numpy as np +import torch +import torch.nn.functional as F + +try: + import wandb +except ImportError: + wandb = None + +from open_clip import ClipLoss, gather_features +from .distributed import is_master +from .zero_shot import zero_shot_eval + + +class AverageMeter(object): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def 
unwrap_model(model): + if hasattr(model, "module"): + return model.module + else: + return model + + +def train_one_epoch( + model, data, epoch, optimizer, scaler, scheduler, args, tb_writer=None +): + device = torch.device(args.device) + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + model.train() + loss = ClipLoss( + local_loss=args.local_loss, + gather_with_grad=args.gather_with_grad, + cache_labels=True, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + weight_loss_kappa=args.kappa, + ) + + dataloader, sampler = data["train"].dataloader, data["train"].sampler + if args.distributed and sampler is not None: + sampler.set_epoch(epoch) + num_batches_per_epoch = dataloader.num_batches + sample_digits = math.ceil(math.log(dataloader.num_samples + 1, 10)) + + # for toy dataset + if args.dataset_type == "toy": + dataloader.dataset.generate_queue() + + loss_m = AverageMeter() + batch_time_m = AverageMeter() + data_time_m = AverageMeter() + end = time.time() + + for i, batch in enumerate(dataloader): + # logging.info(f"batch {i} of {num_batches_per_epoch}") + step = num_batches_per_epoch * epoch + i + if isinstance(scheduler, dict): + for s in scheduler.values(): + s(step) + else: + scheduler(step) + audios = batch # contains mel_spec, wavform, and longer list + texts = batch["text"] + # audios = audios.to(device=device, non_blocking=True) + # texts = texts.to(device=device, non_blocking=True) + + data_time_m.update(time.time() - end) + if isinstance(optimizer, dict): + for o_ in optimizer.values(): + o_.zero_grad() + else: + optimizer.zero_grad() + + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.clap_mlploss: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + 
logit_scale_t=logit_scale_t, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + ) + else: + total_loss = loss( + audio_features=audio_features, + text_features=text_features, + logit_scale_a=logit_scale_a, + ) + if isinstance(optimizer, dict): + if scaler is not None: + scaler.scale(total_loss).backward() + for o_ in optimizer.values(): + if args.horovod: + o_.synchronize() + scaler.unscale_(o_) + with o_.skip_synchronize(): + scaler.step(o_) + else: + scaler.step(o_) + scaler.update() + else: + total_loss.backward() + for o_ in optimizer.values(): + o_.step() + else: + if scaler is not None: + scaler.scale(total_loss).backward() + if args.horovod: + optimizer.synchronize() + scaler.unscale_(optimizer) + with optimizer.skip_synchronize(): + scaler.step(optimizer) + else: + scaler.step(optimizer) + scaler.update() + else: + total_loss.backward() + optimizer.step() + + # Note: we clamp to 4.6052 = ln(100), as in the original paper. + with torch.no_grad(): + unwrap_model(model).logit_scale_a.clamp_(0, math.log(100)) + if args.clap_mlploss: + unwrap_model(model).logit_scale_t.clamp_(0, math.log(100)) + + batch_time_m.update(time.time() - end) + end = time.time() + batch_count = i + 1 + if is_master(args) and (i % 100 == 0 or batch_count == num_batches_per_epoch): + if isinstance(audios, dict): + batch_size = len(audios["waveform"]) + else: + batch_size = len(audios) + num_samples = batch_count * batch_size * args.world_size + samples_per_epoch = dataloader.num_samples + percent_complete = 100.0 * batch_count / num_batches_per_epoch + + # NOTE loss is coarsely sampled, just master node and per log update + loss_m.update(total_loss.item(), batch_size) + logit_scale_scalar_a = logit_scale_a.item() + logit_scale_scalar_t = logit_scale_t.item() + if isinstance(optimizer, dict): + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: 
{loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {[o_.param_groups[0]['lr'] for o_ in optimizer.values()]} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": [o_.param_groups[0]["lr"] for o_ in optimizer.values()], + } + + else: + if args.clap_mlploss: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + f"Logit Scale Text: {logit_scale_scalar_t:.3f}" + ) + + # Save train loss / etc. 
Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "scale_text": logit_scale_scalar_t, + "lr": optimizer.param_groups[0]["lr"], + } + else: + logging.info( + f"Train Epoch: {epoch} [{num_samples:>{sample_digits}}/{samples_per_epoch} ({percent_complete:.0f}%)] " + f"Loss: {loss_m.val:#.5g} ({loss_m.avg:#.4g}) " + f"Data (t): {data_time_m.avg:.3f} " + f"Batch (t): {batch_time_m.avg:.3f} " + f"LR: {optimizer.param_groups[0]['lr']:5f} " + f"Logit Scale Audio: {logit_scale_scalar_a:.3f}" + ) + + # Save train loss / etc. Using non avg meter values as loggers have their own smoothing + log_data = { + "loss": loss_m.val, + "data_time": data_time_m.val, + "batch_time": batch_time_m.val, + "scale_audio": logit_scale_scalar_a, + "lr": optimizer.param_groups[0]["lr"], + } + for name, val in log_data.items(): + name = "train/" + name + if tb_writer is not None: + tb_writer.add_scalar(name, val, step) + if args.wandb: + assert wandb is not None, "Please install wandb." + wandb.log({name: val, "step": step}) + + # resetting batch / data time meters per log window + batch_time_m.reset() + data_time_m.reset() + # end for + + +def evaluate(model, data, epoch, args, tb_writer=None): + metrics = {} + if not args.parallel_eval: + if not is_master(args): + return metrics + device = torch.device(args.device) + model.eval() + + # CHANGE + # zero_shot_metrics = zero_shot_eval(model, data, epoch, args) + # metrics.update(zero_shot_metrics) + if is_master(args): + print("Evaluating...") + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + if args.val_dataset_names == ["Clotho", "audiocaps"]: + # if only clotho and audiocaps are used, then we will use a different evaluation function. + # This is because in the Clotho and audiocaps valid and test set, there are 5 text for 1 audio. 
+ if args.parallel_eval: + # (yusong): just a hack here. Don't use parallel eval when evaluating only clotho and audiocaps. + raise NotImplementedError( + "Parallel evaluation not supported for eval only Clotho and audiocaps." + ) + val_metrics_per_dataset = evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer + ) + for m in val_metrics_per_dataset.values(): + metrics.update(m) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + metrics = select_top_metric_clotho_audiocaps( + metrics, val_metrics_per_dataset, args + ) + elif "val" in data and ( + args.val_frequency + and ((epoch % args.val_frequency) == 0 or epoch == args.epochs) + ): + dataloader = data["val"].dataloader + num_samples = 0 + samples_per_val = dataloader.num_samples + + # FIXME this does not scale past small eval datasets + # all_audio_features @ all_text_features will blow up memory and compute very quickly + eval_info = {} + if args.clap_mlploss: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + "all_text_features_mlp": [], + } # cumulative_loss = 0.0 + else: + eval_info["all"] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } # cumu + # all_audio_features, all_text_features, all_audio_features_mlp, all_text_features_mlp = [], [], [], [] + with torch.no_grad(): + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + texts = batch["text"] + # audios = audios.to(device=device, non_blocking=True) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for name in all_names: + if name not in eval_info.keys(): + if args.clap_mlploss: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + "all_audio_features_mlp": [], + 
"all_text_features_mlp": [], + } + else: + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with autocast(): + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + logit_scale_a, + logit_scale_t, + ) = model(audios, texts, device) + + if args.parallel_eval: + # multi-GPU eval + if args.clap_mlploss: + ( + audio_features, + text_features, + audio_features_mlp, + text_features_mlp, + ) = gather_features( + audio_features=audio_features, + text_features=text_features, + audio_features_mlp=audio_features_mlp, + text_features_mlp=text_features_mlp, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + ) + else: + (audio_features, text_features,) = gather_features( + audio_features=audio_features, + text_features=text_features, + local_loss=False, + gather_with_grad=False, + rank=args.rank, + world_size=args.world_size, + use_horovod=args.horovod, + mlp_loss=args.clap_mlploss, + ) + + if is_master(args): + num_samples += audio_features.shape[0] + for n in [*all_names, "all"]: + if n == "all": + eval_info[n]["all_audio_features"].append( + audio_features.cpu() + ) + eval_info[n]["all_text_features"].append( + text_features.cpu() + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + audio_features_mlp.cpu() + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu() + ) + else: + idx = np.where( + np.array( + [ + "-".join(b.split("/")[-3:-1]) + for b in batch["__url__"] + ] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features"].append( + text_features.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + if args.clap_mlploss: + eval_info[n]["all_audio_features_mlp"].append( + 
audio_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + eval_info[n]["all_text_features_mlp"].append( + text_features_mlp.cpu().index_select( + 0, torch.tensor(idx).long() + ) + ) + # print(f'eval step {i}') # (yusong): for debug + + # cumulative_loss += total_loss * batch_size + # num_samples += batch_size + if is_master(args) and (i % 100) == 0: # and i != 0: + logging.info( + f"Eval Epoch: {epoch} [{num_samples} / {samples_per_val}]" + ) + if is_master(args): + val_metrics_per_dataset = {} + for n in eval_info.keys(): + if args.clap_mlploss: + metrics_single_dataset = get_metrics( + audio_features=torch.cat( + eval_info[n]["all_audio_features"] + ), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + audio_features_mlp=torch.cat( + eval_info[n]["all_audio_features_mlp"] + ), + text_features_mlp=torch.cat( + eval_info[n]["all_text_features_mlp"] + ), + logit_scale_t=logit_scale_t.cpu(), + mlp_loss=args.clap_mlploss, + ) + else: + metrics_single_dataset = get_metrics( + audio_features=torch.cat( + eval_info[n]["all_audio_features"] + ), + text_features=torch.cat(eval_info[n]["all_text_features"]), + logit_scale_a=logit_scale_a.cpu(), + mlp_loss=args.clap_mlploss, + ) + val_metrics_per_dataset[n] = { + n + "/" + k: v for k, v in metrics_single_dataset.items() + } + metrics.update(val_metrics_per_dataset[n]) + if "epoch" not in metrics.keys(): + metrics.update({"epoch": epoch}) + if is_master(args): + if not metrics: + return metrics + + logging.info( + f"Eval Epoch: {epoch} " + + "\n".join( + [ + "\t".join([f"{k}: {round(v, 4):.4f}" for k, v in m.items()]) + for m in val_metrics_per_dataset.values() + ] + ) + ) + + if args.save_logs: + for name, val in metrics.items(): + if tb_writer is not None: + tb_writer.add_scalar(f"val/{name}", val, epoch) + + with open(os.path.join(args.checkpoint_path, "results.jsonl"), "a+") as f: + f.write(json.dumps(metrics)) + f.write("\n") + + if args.wandb: + 
assert wandb is not None, "Please install wandb." + for name, val in metrics.items(): + wandb.log({f"val/{name}": val, "epoch": epoch}) + + return metrics + else: + return metrics + + +def get_metrics( + audio_features, + text_features, + logit_scale_a, + audio_features_mlp=None, + text_features_mlp=None, + logit_scale_t=None, + mlp_loss=False, +): + metrics = {} + if mlp_loss: + # Set up audio to text & text to audio similary matrice + a_logits_per_audio = ( + (logit_scale_a * audio_features @ text_features_mlp.t()).detach().cpu() + ) + a_logits_per_text = a_logits_per_audio.t().detach().cpu() + t_logits_per_audio = ( + (logit_scale_t * audio_features_mlp @ text_features.t()).detach().cpu() + ) + t_logits_per_text = t_logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(a_logits_per_audio, labels) + + F.cross_entropy(a_logits_per_text, labels) + + F.cross_entropy(t_logits_per_audio, labels) + + F.cross_entropy(t_logits_per_text, labels) + ) / 4 + + metrics[f"cumulative_loss"] = total_loss.item() + metrics[f"num_samples"] = audio_features.shape[0] + + logits = { + "audio_to_text": (a_logits_per_audio + t_logits_per_audio) / 2, + "text_to_audio": (a_logits_per_text + t_logits_per_text) / 2, + } + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + else: + # print("text_features", text_features) + # print("text_features.shape", text_features.shape) + logits_per_audio = ( + (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + ) + logits_per_text = logits_per_audio.t().detach().cpu() + + labels = torch.arange(audio_features.shape[0]).long() + # Change the loss from two terms into four terms with 2x2 combined CE loss + total_loss = ( + F.cross_entropy(logits_per_audio, labels) + + F.cross_entropy(logits_per_text, labels) + ) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + 
metrics[f"num_samples"] = audio_features.shape[0] + + logits = {"audio_to_text": logits_per_audio, "text_to_audio": logits_per_text} + + ground_truth = torch.arange(len(text_features)).view(-1, 1) + + for name, logit in logits.items(): + ranking = torch.argsort(logit, descending=True) + preds = torch.where(ranking == ground_truth)[ + 1 + ] # (yusong) this line is slow because it uses single thread + preds = preds.detach().cpu().numpy() + metrics[f"{name}_mean_rank"] = preds.mean() + 1 + metrics[f"{name}_median_rank"] = np.floor(np.median(preds)) + 1 + for k in [1, 5, 10]: + metrics[f"{name}_R@{k}"] = np.mean(preds < k) + # map@10 + metrics[f"{name}_mAP@10"] = np.mean(np.where(preds < 10, 1 / (preds + 1), 0.0)) + + return metrics + + +def evaluate_clotho_audiocaps( + model, data, epoch, args, autocast, device, tb_writer=None +): + """ + Adapted from https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py. + 1. for text-to-audio retrieval, do 5 times and average the results + 2. for R@1, R@5, R@10 in audio-to-text retrieval, take the best rank among 5 text + 3. for map@10 in audio-to-text retrieval: + 3.1: sort the rank of 5 text + 3.2: exclude the rank >=10 (0-index) + 3.3: compute the map regarding the remaining ranks: np.mean(np.arange(1, len(ranks)+1) / ranks). + (3.3) That is, take the top ranks of 5 text that is < 10, and assign the descending number as ground truth. + (3.3) E.g.: the ground truth of first rank of the 5 text should be 1, the second rank should be 2, etc. + """ + # TODO: (yusong) only support single GPU evaluation and only support non-mlp case for now. 
+ dataloader = data["val"].dataloader + with torch.no_grad(): + eval_info = {} + for i, batch in enumerate(dataloader): + audios = batch # contains mel_spec, wavform, and longer list + + # each item in the list has 5 texts + if args.tmodel == "transformer": + from open_clip import tokenize + + texts = [tokenize(t) for t in batch["full_text"]] + texts = torch.cat(texts) + else: + from .data import tokenizer + + texts = [ + tokenizer(t) for t in batch["full_text"] + ] # 5 texts for each audio + texts = { + k: torch.cat([t[k] for t in texts]) for k in texts[0].keys() + } # 5 x batch + + # audios = audios.to(device=device, non_blocking=True) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for name in all_names: + if name not in eval_info.keys(): + # we will not use mlp outputs even if args.clap_mlploss=True + eval_info[name] = { + "cumulative_loss": 0.0, + "num_samples": 0, + "all_audio_features": [], + "all_text_features": [], + } + with autocast(): + audio_features = model(audios, None, device) + text_features = model(None, texts, device) + audio_features = F.normalize(audio_features, dim=-1) + text_features = F.normalize(text_features, dim=-1) + + all_names = list( + set(["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]]) + ) + for n in all_names: + idx = np.where( + np.array( + ["-".join(b.split("/")[-3:-1]) for b in batch["__url__"]] + ) + == n + )[0] + eval_info[n]["all_audio_features"].append( + audio_features.cpu().index_select(0, torch.tensor(idx).long()) + ) + # (yusong) please double-check. This is for selecting 5 text features at once. + # because idx is a list of indices in size of num_samples, + # and text_features is a tensor of size (5*num_samples, dim) + # so we need to select 5 consecutive indices at once for a single index in idx. 
+ eval_info[n]["all_text_features"].append( + text_features.cpu() + .reshape([-1, 5, text_features.shape[1]]) + .index_select(0, torch.tensor(idx).long()) + .reshape([-1, text_features.shape[1]]) + ) + + val_metrics_all = {} + + for n in eval_info.keys(): + logit_scale_a, logit_scale_t = model(None, None, device) + logit_scale_a = logit_scale_a.cpu() + + audio_features = torch.cat(eval_info[n]["all_audio_features"], dim=0) + text_features = torch.cat(eval_info[n]["all_text_features"], dim=0) + + logits_per_audio = ( + (logit_scale_a * audio_features @ text_features.t()).detach().cpu() + ) + logits_per_text = logits_per_audio.t().detach().cpu() + + # logits_per_audio shape: [num_samples, num_samples*5] + # logits_per_text shape: [num_samples*5, num_samples] + + logging.info( + f"dataset {n}, logits_per_audio shape: {logits_per_audio.shape}, " + f"logits_per_text shape: {logits_per_text.shape}" + ) + + metrics = {} + num_samples = audio_features.shape[0] + metrics[f"num_samples"] = num_samples + + # (yusong) the following code is very important, please double-check: + # logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d] + # logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + # Those two are retrieving one of the 5 text for each audio. 
+ labels = torch.arange(audio_features.shape[0]).long() + audio_to_text_loss = [ + F.cross_entropy( + logits_per_audio.reshape(num_samples, num_samples, 5)[:, :, d], + labels, + ) + for d in range(5) + ] + text_to_audio_loss = [ + F.cross_entropy( + logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :], + labels, + ) + for d in range(5) + ] + total_loss = (np.mean(audio_to_text_loss) + np.mean(text_to_audio_loss)) / 2 + + metrics[f"cumulative_loss"] = total_loss.item() + + # text to audio: do 5 times + pred_text = [] + for d in range(5): + logit = logits_per_text.reshape(num_samples, 5, num_samples)[:, d, :] + ground_truth = torch.arange(len(logit)).view(-1, 1) + ranking = torch.argsort( + logit, descending=True + ) # [num_samples, num_samples] + preds = torch.where(ranking == ground_truth)[1] + pred_text.append(preds.detach().cpu().numpy()) + pred_text_concat = np.concatenate(pred_text, axis=0) # [5*num_samples] + metrics[f"text_to_audio_mean_rank"] = pred_text_concat.mean() + 1 + metrics[f"text_to_audio_median_rank"] = ( + np.floor(np.median(pred_text_concat)) + 1 + ) + for k in [1, 5, 10]: + metrics[f"text_to_audio_R@{k}"] = np.mean(pred_text_concat < k) + # map@10 + metrics[f"text_to_audio_mAP@10"] = np.mean( + np.where(pred_text_concat < 10, 1 / (pred_text_concat + 1), 0.0) + ) + + # audio to text: take the best result + # for audio to text map 10, sort and assign descending ground truth. + # see https://github.com/XinhaoMei/audio-text_retrieval/blob/main/tools/utils.py#L103 + # map@10 + map_all = [] + pred_audio_all = [] + for d in range(num_samples): + # logits_per_audio: [num_samples, num_samples*5] + logit_single = logits_per_audio[d, :] # [5*num_samples] + # Ground-truth index: [d*5, d*5+1, d*5+2, d*5+3, d*5+4] + ranking = torch.argsort( + logit_single, descending=True + ) # [5*num_samples] + # ranking: the index of first match, second match, ... 
+ ground_truth = torch.arange(d * 5, d * 5 + 5)[None] + all_pred = torch.where( + torch.stack([ranking] * 5) == ground_truth.view(-1, 1) + )[1] + min_pred = torch.min(all_pred) + pred_audio_all.append(min_pred.detach().cpu().numpy()) + all_pred_filter = all_pred[all_pred < 10].detach().cpu().numpy() + # /5 because we have 5 text, so it means for the text rank >=10 we count as 0. + map_single = ( + np.sum( + (np.arange(1, len(all_pred_filter) + 1) / (all_pred_filter + 1)) + ) + / 5 + ) + map_all.append(map_single) + metrics[f"audio_to_text_mAP@10"] = np.mean(map_all) + for k in [1, 5, 10]: + metrics[f"audio_to_text_R@{k}"] = np.mean(np.array(pred_audio_all) < k) + + val_metrics_all[n] = {n + "/" + k: v for k, v in metrics.items()} + return val_metrics_all + + +def calculate_selection_performance_clotho_audiocaps(val_metrics_per_dataset): + """ + Calculate performance for Clotho+AudioCaps for model selection. + """ + selection_performance_all = [] + for n in val_metrics_per_dataset.keys(): + selection_performance = ( + val_metrics_per_dataset[n][f"{n}/audio_to_text_mAP@10"] + + val_metrics_per_dataset[n][f"{n}/text_to_audio_mAP@10"] + ) / 2 + selection_performance_all.append(selection_performance) + return np.mean(selection_performance_all) + + +def select_top_metric_clotho_audiocaps(metrics, val_metrics_per_dataset, args): + # val_metrics_per_dataset: dict, key: dataset name, value: dict, key: metric name, value: metric value + # metrics: dict, key: metric name, value: metric value + # Hack: use args to save the top performance + if not hasattr(args, "top_selection_performance"): + selection_performance = calculate_selection_performance_clotho_audiocaps( + val_metrics_per_dataset + ) + # TODO: write the if and else together + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[ + k.split("/")[0] + "-top" + "/" + k.split("/")[1] + ] = val_metrics_per_dataset[n][k] + 
metric_update["top_selection_performance"] = selection_performance + metric_update["top-selection-epoch"] = metrics["epoch"] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance + else: + selection_performance_new = calculate_selection_performance_clotho_audiocaps( + val_metrics_per_dataset + ) + selection_performance_old = args.top_selection_performance + if selection_performance_new > selection_performance_old: + metric_update = {} + for n in val_metrics_per_dataset.keys(): + for k in val_metrics_per_dataset[n].keys(): + metric_update[ + k.split("/")[0] + "-top" + "/" + k.split("/")[1] + ] = val_metrics_per_dataset[n][k] + metric_update["top_selection_performance"] = selection_performance_new + metric_update["top-selection-epoch"] = metrics["epoch"] + metrics.update(metric_update) + args.top_metric = metric_update + args.top_selection_performance = selection_performance_new + else: + metrics.update(args.top_metric) + return metrics diff --git a/audioldm/clap/training/zero_shot.py b/audioldm/clap/training/zero_shot.py new file mode 100644 index 0000000000000000000000000000000000000000..28b8fccc1af17fc69002857a7f529ac041c374f2 --- /dev/null +++ b/audioldm/clap/training/zero_shot.py @@ -0,0 +1,95 @@ +# NOTE: This script is currently not supported for CLAP. 
+import logging +from contextlib import suppress + +import torch +import torch.nn.functional as F +from tqdm import tqdm + +from open_clip import tokenize +from .imagenet_zeroshot_data import imagenet_classnames, openai_imagenet_template + + +def zero_shot_classifier(model, classnames, templates, args): + with torch.no_grad(): + zeroshot_weights = [] + for classname in tqdm(classnames): + texts = [template(classname) for template in templates] # format with class + texts = tokenize(texts).to(args.device) # tokenize + if args.distributed and not args.horovod: + class_embeddings = model.module.encode_text(texts) + else: + class_embeddings = model.encode_text(texts) + class_embedding = F.normalize(class_embeddings, dim=-1).mean(dim=0) + class_embedding /= class_embedding.norm() + zeroshot_weights.append(class_embedding) + zeroshot_weights = torch.stack(zeroshot_weights, dim=1).to(args.device) + return zeroshot_weights + + +def accuracy(output, target, topk=(1,)): + pred = output.topk(max(topk), 1, True, True)[1].t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) + return [ + float(correct[:k].reshape(-1).float().sum(0, keepdim=True).cpu().numpy()) + for k in topk + ] + + +def run(model, classifier, dataloader, args): + autocast = torch.cuda.amp.autocast if args.precision == "amp" else suppress + with torch.no_grad(): + top1, top5, n = 0.0, 0.0, 0.0 + for images, target in tqdm(dataloader, unit_scale=args.batch_size): + images = images.to(args.device) + target = target.to(args.device) + + with autocast(): + # predict + if args.distributed and not args.horovod: + image_features = model.module.encode_image(images) + else: + image_features = model.encode_image(images) + image_features = F.normalize(image_features, dim=-1) + logits = 100.0 * image_features @ classifier + + # measure accuracy + acc1, acc5 = accuracy(logits, target, topk=(1, 5)) + top1 += acc1 + top5 += acc5 + n += images.size(0) + + top1 = top1 / n + top5 = top5 / n + return top1, top5 + + +def 
zero_shot_eval(model, data, epoch, args): + if "imagenet-val" not in data and "imagenet-v2" not in data: + return {} + if args.zeroshot_frequency == 0: + return {} + if (epoch % args.zeroshot_frequency) != 0 and epoch != args.epochs: + return {} + + logging.info("Starting zero-shot imagenet.") + + logging.info("Building zero-shot classifier") + classifier = zero_shot_classifier( + model, imagenet_classnames, openai_imagenet_template, args + ) + + logging.info("Using classifier") + results = {} + if "imagenet-val" in data: + top1, top5 = run(model, classifier, data["imagenet-val"].dataloader, args) + results["imagenet-zeroshot-val-top1"] = top1 + results["imagenet-zeroshot-val-top5"] = top5 + if "imagenet-v2" in data: + top1, top5 = run(model, classifier, data["imagenet-v2"].dataloader, args) + results["imagenetv2-zeroshot-val-top1"] = top1 + results["imagenetv2-zeroshot-val-top5"] = top5 + + logging.info("Finished zero-shot imagenet.") + + return results diff --git a/audioldm/hifigan/__init__.py b/audioldm/hifigan/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e0ae476fe58c48e998c56234a55b871beba4042d --- /dev/null +++ b/audioldm/hifigan/__init__.py @@ -0,0 +1,7 @@ +from .models import Generator + + +class AttrDict(dict): + def __init__(self, *args, **kwargs): + super(AttrDict, self).__init__(*args, **kwargs) + self.__dict__ = self diff --git a/audioldm/hifigan/__pycache__/__init__.cpython-310.pyc b/audioldm/hifigan/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..82a66c070a05db0798068d750b2e9741af425fca Binary files /dev/null and b/audioldm/hifigan/__pycache__/__init__.cpython-310.pyc differ diff --git a/audioldm/hifigan/__pycache__/models.cpython-310.pyc b/audioldm/hifigan/__pycache__/models.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0b43c7481ef23bec3e301cb134dbe0ac8848d6e0 Binary files /dev/null and 
b/audioldm/hifigan/__pycache__/models.cpython-310.pyc differ diff --git a/audioldm/hifigan/__pycache__/utilities.cpython-310.pyc b/audioldm/hifigan/__pycache__/utilities.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..03b513236472f512a7337eddecc121db5c99d9ca Binary files /dev/null and b/audioldm/hifigan/__pycache__/utilities.cpython-310.pyc differ diff --git a/audioldm/hifigan/models.py b/audioldm/hifigan/models.py new file mode 100644 index 0000000000000000000000000000000000000000..c4382cc39de0463f9b7c0f33f037dbc233e7cb36 --- /dev/null +++ b/audioldm/hifigan/models.py @@ -0,0 +1,174 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch.nn import Conv1d, ConvTranspose1d +from torch.nn.utils import weight_norm, remove_weight_norm + +LRELU_SLOPE = 0.1 + + +def init_weights(m, mean=0.0, std=0.01): + classname = m.__class__.__name__ + if classname.find("Conv") != -1: + m.weight.data.normal_(mean, std) + + +def get_padding(kernel_size, dilation=1): + return int((kernel_size * dilation - dilation) / 2) + + +class ResBlock(torch.nn.Module): + def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)): + super(ResBlock, self).__init__() + self.h = h + self.convs1 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[0], + padding=get_padding(kernel_size, dilation[0]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[1], + padding=get_padding(kernel_size, dilation[1]), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=dilation[2], + padding=get_padding(kernel_size, dilation[2]), + ) + ), + ] + ) + self.convs1.apply(init_weights) + + self.convs2 = nn.ModuleList( + [ + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 
1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + weight_norm( + Conv1d( + channels, + channels, + kernel_size, + 1, + dilation=1, + padding=get_padding(kernel_size, 1), + ) + ), + ] + ) + self.convs2.apply(init_weights) + + def forward(self, x): + for c1, c2 in zip(self.convs1, self.convs2): + xt = F.leaky_relu(x, LRELU_SLOPE) + xt = c1(xt) + xt = F.leaky_relu(xt, LRELU_SLOPE) + xt = c2(xt) + x = xt + x + return x + + def remove_weight_norm(self): + for l in self.convs1: + remove_weight_norm(l) + for l in self.convs2: + remove_weight_norm(l) + + +class Generator(torch.nn.Module): + def __init__(self, h): + super(Generator, self).__init__() + self.h = h + self.num_kernels = len(h.resblock_kernel_sizes) + self.num_upsamples = len(h.upsample_rates) + self.conv_pre = weight_norm( + Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3) + ) + resblock = ResBlock + + self.ups = nn.ModuleList() + for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)): + self.ups.append( + weight_norm( + ConvTranspose1d( + h.upsample_initial_channel // (2**i), + h.upsample_initial_channel // (2 ** (i + 1)), + k, + u, + padding=(k - u) // 2, + ) + ) + ) + + self.resblocks = nn.ModuleList() + for i in range(len(self.ups)): + ch = h.upsample_initial_channel // (2 ** (i + 1)) + for j, (k, d) in enumerate( + zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes) + ): + self.resblocks.append(resblock(h, ch, k, d)) + + self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3)) + self.ups.apply(init_weights) + self.conv_post.apply(init_weights) + + def forward(self, x): + x = self.conv_pre(x) + for i in range(self.num_upsamples): + x = F.leaky_relu(x, LRELU_SLOPE) + x = self.ups[i](x) + xs = None + for j in range(self.num_kernels): + if xs is None: + xs = self.resblocks[i * self.num_kernels + j](x) + else: + xs += self.resblocks[i * self.num_kernels + j](x) + x = xs / self.num_kernels + x = F.leaky_relu(x) + x = self.conv_post(x) + x = 
torch.tanh(x) + + return x + + def remove_weight_norm(self): + # print("Removing weight norm...") + for l in self.ups: + remove_weight_norm(l) + for l in self.resblocks: + l.remove_weight_norm() + remove_weight_norm(self.conv_pre) + remove_weight_norm(self.conv_post) diff --git a/audioldm/hifigan/utilities.py b/audioldm/hifigan/utilities.py new file mode 100644 index 0000000000000000000000000000000000000000..ea9f958e460a77fd4936a6edf59403dd3ea617ab --- /dev/null +++ b/audioldm/hifigan/utilities.py @@ -0,0 +1,86 @@ +import os +import json + +import torch +import numpy as np + +import audioldm.hifigan as hifigan + +HIFIGAN_16K_64 = { + "resblock": "1", + "num_gpus": 6, + "batch_size": 16, + "learning_rate": 0.0002, + "adam_b1": 0.8, + "adam_b2": 0.99, + "lr_decay": 0.999, + "seed": 1234, + "upsample_rates": [5, 4, 2, 2, 2], + "upsample_kernel_sizes": [16, 16, 8, 4, 4], + "upsample_initial_channel": 1024, + "resblock_kernel_sizes": [3, 7, 11], + "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]], + "segment_size": 8192, + "num_mels": 64, + "num_freq": 1025, + "n_fft": 1024, + "hop_size": 160, + "win_size": 1024, + "sampling_rate": 16000, + "fmin": 0, + "fmax": 8000, + "fmax_for_loss": None, + "num_workers": 4, + "dist_config": { + "dist_backend": "nccl", + "dist_url": "tcp://localhost:54321", + "world_size": 1, + }, +} + + +def get_available_checkpoint_keys(model, ckpt): + print("==> Attemp to reload from %s" % ckpt) + state_dict = torch.load(ckpt)["state_dict"] + current_state_dict = model.state_dict() + new_state_dict = {} + for k in state_dict.keys(): + if ( + k in current_state_dict.keys() + and current_state_dict[k].size() == state_dict[k].size() + ): + new_state_dict[k] = state_dict[k] + else: + print("==> WARNING: Skipping %s" % k) + print( + "%s out of %s keys are matched" + % (len(new_state_dict.keys()), len(state_dict.keys())) + ) + return new_state_dict + + +def get_param_num(model): + num_param = sum(param.numel() for param in model.parameters()) 
+ return num_param + + +def get_vocoder(config, device): + config = hifigan.AttrDict(HIFIGAN_16K_64) + vocoder = hifigan.Generator(config) + vocoder.eval() + vocoder.remove_weight_norm() + vocoder.to(device) + return vocoder + + +def vocoder_infer(mels, vocoder, lengths=None): + vocoder.eval() + with torch.no_grad(): + wavs = vocoder(mels).squeeze(1) + + wavs = (wavs.cpu().numpy() * 32768).astype("int16") + + if lengths is not None: + wavs = wavs[:, :lengths] + + return wavs diff --git a/audioldm/latent_diffusion/__init__.py b/audioldm/latent_diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bf62ec3293e92b5756b960c8151669368552678a Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc differ diff --git a/audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f72f332ef322e86806ac7220ddf9353cbbfca458 Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc differ diff --git a/audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..552143c13a0bb28a7592701e0ba98c7948621225 Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc differ diff --git a/audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a5573eecc4fd0c7eb05d97eef904e0267931d9b3 Binary files /dev/null and 
b/audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc differ diff --git a/audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4a96ed73325d7c8b7191cb59f64e4701764f1687 Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc differ diff --git a/audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc b/audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..52110072bae06e7926cefd635634eafcd8d1cba0 Binary files /dev/null and b/audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc differ diff --git a/audioldm/latent_diffusion/attention.py b/audioldm/latent_diffusion/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..27886f5ee3c7eb856100503b838399106ef00051 --- /dev/null +++ b/audioldm/latent_diffusion/attention.py @@ -0,0 +1,469 @@ +from inspect import isfunction +import math +import torch +import torch.nn.functional as F +from torch import nn +from einops import rearrange + +from audioldm.latent_diffusion.util import checkpoint + + +def exists(val): + return val is not None + + +def uniq(arr): + return {el: True for el in arr}.keys() + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def max_neg_value(t): + return -torch.finfo(t.dtype).max + + +def init_(tensor): + dim = tensor.shape[-1] + std = 1 / math.sqrt(dim) + tensor.uniform_(-std, std) + return tensor + + +# feedforward +class GEGLU(nn.Module): + def __init__(self, dim_in, dim_out): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def forward(self, x): + x, gate = self.proj(x).chunk(2, dim=-1) + return x * F.gelu(gate) + + +class FeedForward(nn.Module): + def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0): + super().__init__() + inner_dim = 
int(dim * mult) + dim_out = default(dim_out, dim) + project_in = ( + nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU()) + if not glu + else GEGLU(dim, inner_dim) + ) + + self.net = nn.Sequential( + project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out) + ) + + def forward(self, x): + return self.net(x) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def Normalize(in_channels): + return torch.nn.GroupNorm( + num_groups=32, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class LinearAttention(nn.Module): + def __init__(self, dim, heads=4, dim_head=32): + super().__init__() + self.heads = heads + hidden_dim = dim_head * heads + self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False) + self.to_out = nn.Conv2d(hidden_dim, dim, 1) + + def forward(self, x): + b, c, h, w = x.shape + qkv = self.to_qkv(x) + q, k, v = rearrange( + qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3 + ) + k = k.softmax(dim=-1) + context = torch.einsum("bhdn,bhen->bhde", k, v) + out = torch.einsum("bhde,bhdn->bhen", context, q) + out = rearrange( + out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w + ) + return self.to_out(out) + + +class SpatialSelfAttention(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = Normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=1, stride=1, padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, 
h, w = q.shape + q = rearrange(q, "b c h w -> b (h w) c") + k = rearrange(k, "b c h w -> b c (h w)") + w_ = torch.einsum("bij,bjk->bik", q, k) + + w_ = w_ * (int(c) ** (-0.5)) + w_ = torch.nn.functional.softmax(w_, dim=2) + + # attend to values + v = rearrange(v, "b c h w -> b c (h w)") + w_ = rearrange(w_, "b i j -> b j i") + h_ = torch.einsum("bij,bjk->bik", v, w_) + h_ = rearrange(h_, "b c (h w) -> b c h w", h=h) + h_ = self.proj_out(h_) + + return x + h_ + + +class CrossAttention(nn.Module): + """ + ### Cross Attention Layer + This falls-back to self-attention when conditional embeddings are not specified. + """ + + # use_flash_attention: bool = True + use_flash_attention: bool = False + + def __init__( + self, + query_dim, + context_dim=None, + heads=8, + dim_head=64, + dropout=0.0, + is_inplace: bool = True, + ): + # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True): + """ + :param d_model: is the input embedding size + :param n_heads: is the number of attention heads + :param d_head: is the size of a attention head + :param d_cond: is the size of the conditional embeddings + :param is_inplace: specifies whether to perform the attention softmax computation inplace to + save memory + """ + super().__init__() + + self.is_inplace = is_inplace + self.n_heads = heads + self.d_head = dim_head + + # Attention scaling factor + self.scale = dim_head**-0.5 + + # The normal self-attention layer + if context_dim is None: + context_dim = query_dim + + # Query, key and value mappings + d_attn = dim_head * heads + self.to_q = nn.Linear(query_dim, d_attn, bias=False) + self.to_k = nn.Linear(context_dim, d_attn, bias=False) + self.to_v = nn.Linear(context_dim, d_attn, bias=False) + + # Final linear layer + self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout)) + + # Setup [flash attention](https://github.com/HazyResearch/flash-attention). 
+ # Flash attention is only used if it's installed + # and `CrossAttention.use_flash_attention` is set to `True`. + try: + # You can install flash attention by cloning their Github repo, + # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention) + # and then running `python setup.py install` + from flash_attn.flash_attention import FlashAttention + + self.flash = FlashAttention() + # Set the scale for scaled dot-product attention. + self.flash.softmax_scale = self.scale + # Set to `None` if it's not installed + except ImportError: + self.flash = None + + def forward(self, x, context=None, mask=None): + """ + :param x: are the input embeddings of shape `[batch_size, height * width, d_model]` + :param cond: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]` + """ + + # If `cond` is `None` we perform self attention + has_cond = context is not None + if not has_cond: + context = x + + # Get query, key and value vectors + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + # Use flash attention if it's available and the head size is less than or equal to `128` + if ( + CrossAttention.use_flash_attention + and self.flash is not None + and not has_cond + and self.d_head <= 128 + ): + return self.flash_attention(q, k, v) + # Otherwise, fallback to normal attention + else: + return self.normal_attention(q, k, v) + + def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Flash Attention + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Get batch size and number of elements along sequence axis (`width * height`) + batch_size, seq_len, _ = q.shape + + # Stack `q`, `k`, `v` vectors for flash attention, to get a single 
tensor of + # shape `[batch_size, seq_len, 3, n_heads * d_head]` + qkv = torch.stack((q, k, v), dim=2) + # Split the heads + qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head) + + # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to + # fit this size. + if self.d_head <= 32: + pad = 32 - self.d_head + elif self.d_head <= 64: + pad = 64 - self.d_head + elif self.d_head <= 128: + pad = 128 - self.d_head + else: + raise ValueError(f"Head size ${self.d_head} too large for Flash Attention") + + # Pad the heads + if pad: + qkv = torch.cat( + (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1 + ) + + # Compute attention + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]` + # TODO here I add the dtype changing + out, _ = self.flash(qkv.type(torch.float16)) + # Truncate the extra head size + out = out[:, :, :, : self.d_head].float() + # Reshape to `[batch_size, seq_len, n_heads * d_head]` + out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head) + + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor): + """ + #### Normal Attention + + :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]` + """ + + # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]` + q = q.view(*q.shape[:2], self.n_heads, -1) # [bs, 64, 20, 32] + k = k.view(*k.shape[:2], self.n_heads, -1) # [bs, 1, 20, 32] + v = v.view(*v.shape[:2], self.n_heads, -1) + + # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$ + attn = torch.einsum("bihd,bjhd->bhij", q, k) * 
self.scale + + # Compute softmax + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$ + if self.is_inplace: + half = attn.shape[0] // 2 + attn[half:] = attn[half:].softmax(dim=-1) + attn[:half] = attn[:half].softmax(dim=-1) + else: + attn = attn.softmax(dim=-1) + + # Compute attention output + # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$ + # attn: [bs, 20, 64, 1] + # v: [bs, 1, 20, 32] + out = torch.einsum("bhij,bjhd->bihd", attn, v) + # Reshape to `[batch_size, height * width, n_heads * d_head]` + out = out.reshape(*out.shape[:2], -1) + # Map to `[batch_size, height * width, d_model]` with a linear layer + return self.to_out(out) + + +# class CrossAttention(nn.Module): +# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.): +# super().__init__() +# inner_dim = dim_head * heads +# context_dim = default(context_dim, query_dim) + +# self.scale = dim_head ** -0.5 +# self.heads = heads + +# self.to_q = nn.Linear(query_dim, inner_dim, bias=False) +# self.to_k = nn.Linear(context_dim, inner_dim, bias=False) +# self.to_v = nn.Linear(context_dim, inner_dim, bias=False) + +# self.to_out = nn.Sequential( +# nn.Linear(inner_dim, query_dim), +# nn.Dropout(dropout) +# ) + +# def forward(self, x, context=None, mask=None): +# h = self.heads + +# q = self.to_q(x) +# context = default(context, x) +# k = self.to_k(context) +# v = self.to_v(context) + +# q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + +# sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + +# if exists(mask): +# mask = rearrange(mask, 'b ... 
-> b (...)') +# max_neg_value = -torch.finfo(sim.dtype).max +# mask = repeat(mask, 'b j -> (b h) () j', h=h) +# sim.masked_fill_(~mask, max_neg_value) + +# # attention, what we cannot get enough of +# attn = sim.softmax(dim=-1) + +# out = einsum('b i j, b j d -> b i d', attn, v) +# out = rearrange(out, '(b h) n d -> b n (h d)', h=h) +# return self.to_out(out) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim, + n_heads, + d_head, + dropout=0.0, + context_dim=None, + gated_ff=True, + checkpoint=True, + ): + super().__init__() + self.attn1 = CrossAttention( + query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff) + self.attn2 = CrossAttention( + query_dim=dim, + context_dim=context_dim, + heads=n_heads, + dim_head=d_head, + dropout=dropout, + ) # is self-attn if context is none + self.norm1 = nn.LayerNorm(dim) + self.norm2 = nn.LayerNorm(dim) + self.norm3 = nn.LayerNorm(dim) + self.checkpoint = checkpoint + + def forward(self, x, context=None): + if context is None: + return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint) + else: + return checkpoint( + self._forward, (x, context), self.parameters(), self.checkpoint + ) + + def _forward(self, x, context=None): + x = self.attn1(self.norm1(x)) + x + x = self.attn2(self.norm2(x), context=context) + x + x = self.ff(self.norm3(x)) + x + return x + + +class SpatialTransformer(nn.Module): + """ + Transformer block for image-like data. + First, project the input (aka embedding) + and reshape to b, t, d. + Then apply standard transformer action. 
+ Finally, reshape to image + """ + + def __init__( + self, + in_channels, + n_heads, + d_head, + depth=1, + dropout=0.0, + context_dim=None, + no_context=False, + ): + super().__init__() + + if no_context: + context_dim = None + + self.in_channels = in_channels + inner_dim = n_heads * d_head + self.norm = Normalize(in_channels) + + self.proj_in = nn.Conv2d( + in_channels, inner_dim, kernel_size=1, stride=1, padding=0 + ) + + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim + ) + for d in range(depth) + ] + ) + + self.proj_out = zero_module( + nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + ) + + def forward(self, x, context=None): + # note: if no context is given, cross-attention defaults to self-attention + b, c, h, w = x.shape + x_in = x + x = self.norm(x) + x = self.proj_in(x) + x = rearrange(x, "b c h w -> b (h w) c") + for block in self.transformer_blocks: + x = block(x, context=context) + x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w) + x = self.proj_out(x) + return x + x_in diff --git a/audioldm/latent_diffusion/ddim.py b/audioldm/latent_diffusion/ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..732002b048e9a193313aa0ef9a353d4fc078be72 --- /dev/null +++ b/audioldm/latent_diffusion/ddim.py @@ -0,0 +1,377 @@ +"""SAMPLING ONLY.""" + +import torch +import numpy as np +from tqdm import tqdm + +from audioldm.latent_diffusion.util import ( + make_ddim_sampling_parameters, + make_ddim_timesteps, + noise_like, + extract_into_tensor, +) + + +class DDIMSampler(object): + def __init__(self, model, schedule="linear", **kwargs): + super().__init__() + self.model = model + self.ddpm_num_timesteps = model.num_timesteps + self.schedule = schedule + + def register_buffer(self, name, attr): + if type(attr) == torch.Tensor: + if attr.device != torch.device("cuda"): + attr = attr.to(torch.device("cuda")) + setattr(self, name, attr) 
+ + def make_schedule( + self, ddim_num_steps, ddim_discretize="uniform", ddim_eta=0.0, verbose=True + ): + self.ddim_timesteps = make_ddim_timesteps( + ddim_discr_method=ddim_discretize, + num_ddim_timesteps=ddim_num_steps, + num_ddpm_timesteps=self.ddpm_num_timesteps, + verbose=verbose, + ) + alphas_cumprod = self.model.alphas_cumprod + assert ( + alphas_cumprod.shape[0] == self.ddpm_num_timesteps + ), "alphas have to be defined for each timestep" + to_torch = lambda x: x.clone().detach().to(torch.float32).to(self.model.device) + + self.register_buffer("betas", to_torch(self.model.betas)) + self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod)) + self.register_buffer( + "alphas_cumprod_prev", to_torch(self.model.alphas_cumprod_prev) + ) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer( + "sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_one_minus_alphas_cumprod", + to_torch(np.sqrt(1.0 - alphas_cumprod.cpu())), + ) + self.register_buffer( + "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod.cpu())) + ) + self.register_buffer( + "sqrt_recipm1_alphas_cumprod", + to_torch(np.sqrt(1.0 / alphas_cumprod.cpu() - 1)), + ) + + # ddim sampling parameters + ddim_sigmas, ddim_alphas, ddim_alphas_prev = make_ddim_sampling_parameters( + alphacums=alphas_cumprod.cpu(), + ddim_timesteps=self.ddim_timesteps, + eta=ddim_eta, + verbose=verbose, + ) + self.register_buffer("ddim_sigmas", ddim_sigmas) + self.register_buffer("ddim_alphas", ddim_alphas) + self.register_buffer("ddim_alphas_prev", ddim_alphas_prev) + self.register_buffer("ddim_sqrt_one_minus_alphas", np.sqrt(1.0 - ddim_alphas)) + sigmas_for_original_sampling_steps = ddim_eta * torch.sqrt( + (1 - self.alphas_cumprod_prev) + / (1 - self.alphas_cumprod) + * (1 - self.alphas_cumprod / self.alphas_cumprod_prev) + 
) + self.register_buffer( + "ddim_sigmas_for_original_num_steps", sigmas_for_original_sampling_steps + ) + + @torch.no_grad() + def sample( + self, + S, + batch_size, + shape, + conditioning=None, + callback=None, + normals_sequence=None, + img_callback=None, + quantize_x0=False, + eta=0.0, + mask=None, + x0=None, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + verbose=True, + x_T=None, + log_every_t=100, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ... + **kwargs, + ): + if conditioning is not None: + if isinstance(conditioning, dict): + cbs = conditioning[list(conditioning.keys())[0]].shape[0] + if cbs != batch_size: + print( + f"Warning: Got {cbs} conditionings but batch-size is {batch_size}" + ) + else: + if conditioning.shape[0] != batch_size: + print( + f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}" + ) + + self.make_schedule(ddim_num_steps=S, ddim_eta=eta, verbose=verbose) + # sampling + C, H, W = shape + size = (batch_size, C, H, W) + samples, intermediates = self.ddim_sampling( + conditioning, + size, + callback=callback, + img_callback=img_callback, + quantize_denoised=quantize_x0, + mask=mask, + x0=x0, + ddim_use_original_steps=False, + noise_dropout=noise_dropout, + temperature=temperature, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + x_T=x_T, + log_every_t=log_every_t, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return samples, intermediates + + @torch.no_grad() + def ddim_sampling( + self, + cond, + shape, + x_T=None, + ddim_use_original_steps=False, + callback=None, + timesteps=None, + quantize_denoised=False, + mask=None, + x0=None, + img_callback=None, + log_every_t=100, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + 
corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + device = self.model.betas.device + b = shape[0] + if x_T is None: + img = torch.randn(shape, device=device) + else: + img = x_T + + if timesteps is None: + timesteps = ( + self.ddpm_num_timesteps + if ddim_use_original_steps + else self.ddim_timesteps + ) + elif timesteps is not None and not ddim_use_original_steps: + subset_end = ( + int( + min(timesteps / self.ddim_timesteps.shape[0], 1) + * self.ddim_timesteps.shape[0] + ) + - 1 + ) + timesteps = self.ddim_timesteps[:subset_end] + + intermediates = {"x_inter": [img], "pred_x0": [img]} + time_range = ( + reversed(range(0, timesteps)) + if ddim_use_original_steps + else np.flip(timesteps) + ) + total_steps = timesteps if ddim_use_original_steps else timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = gr.Progress().tqdm(time_range, desc="DDIM Sampler", total=total_steps) + iterator = tqdm(time_range, desc="DDIM Sampler", total=total_steps, leave=False) + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full((b,), step, device=device, dtype=torch.long) + if mask is not None: + assert x0 is not None + img_orig = self.model.q_sample( + x0, ts + ) # TODO deterministic forward pass? 
+ img = ( + img_orig * mask + (1.0 - mask) * img + ) # In the first sampling step, img is pure gaussian noise + + outs = self.p_sample_ddim( + img, + cond, + ts, + index=index, + use_original_steps=ddim_use_original_steps, + quantize_denoised=quantize_denoised, + temperature=temperature, + noise_dropout=noise_dropout, + score_corrector=score_corrector, + corrector_kwargs=corrector_kwargs, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + img, pred_x0 = outs + if callback: + callback(i) + if img_callback: + img_callback(pred_x0, i) + + if index % log_every_t == 0 or index == total_steps - 1: + intermediates["x_inter"].append(img) + intermediates["pred_x0"].append(pred_x0) + + return img, intermediates + + @torch.no_grad() + def stochastic_encode(self, x0, t, use_original_steps=False, noise=None): + # fast, but does not allow for exact reconstruction + # t serves as an index to gather the correct alphas + if use_original_steps: + sqrt_alphas_cumprod = self.sqrt_alphas_cumprod + sqrt_one_minus_alphas_cumprod = self.sqrt_one_minus_alphas_cumprod + else: + sqrt_alphas_cumprod = torch.sqrt(self.ddim_alphas) + sqrt_one_minus_alphas_cumprod = self.ddim_sqrt_one_minus_alphas + + if noise is None: + noise = torch.randn_like(x0) + + return ( + extract_into_tensor(sqrt_alphas_cumprod, t, x0.shape) * x0 + + extract_into_tensor(sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise + ) + + @torch.no_grad() + def decode( + self, + x_latent, + cond, + t_start, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + use_original_steps=False, + ): + + timesteps = ( + np.arange(self.ddpm_num_timesteps) + if use_original_steps + else self.ddim_timesteps + ) + timesteps = timesteps[:t_start] + + time_range = np.flip(timesteps) + total_steps = timesteps.shape[0] + # print(f"Running DDIM Sampling with {total_steps} timesteps") + + # iterator = gr.Progress().tqdm(time_range, desc="Decoding image", 
total=total_steps) + iterator = tqdm(time_range, desc="Decoding image", total=total_steps) + x_dec = x_latent + + for i, step in enumerate(iterator): + index = total_steps - i - 1 + ts = torch.full( + (x_latent.shape[0],), step, device=x_latent.device, dtype=torch.long + ) + x_dec, _ = self.p_sample_ddim( + x_dec, + cond, + ts, + index=index, + use_original_steps=use_original_steps, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + ) + return x_dec + + @torch.no_grad() + def p_sample_ddim( + self, + x, + c, + t, + index, + repeat_noise=False, + use_original_steps=False, + quantize_denoised=False, + temperature=1.0, + noise_dropout=0.0, + score_corrector=None, + corrector_kwargs=None, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + ): + b, *_, device = *x.shape, x.device + + if unconditional_conditioning is None or unconditional_guidance_scale == 1.0: + e_t = self.model.apply_model(x, t, c) + else: + x_in = torch.cat([x] * 2) + t_in = torch.cat([t] * 2) + c_in = torch.cat([unconditional_conditioning, c]) + e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2) + # When unconditional_guidance_scale == 1: only e_t + # When unconditional_guidance_scale == 0: only unconditional + # When unconditional_guidance_scale > 1: add more unconditional guidance + e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond) + + if score_corrector is not None: + assert self.model.parameterization == "eps" + e_t = score_corrector.modify_score( + self.model, e_t, x, t, c, **corrector_kwargs + ) + + alphas = self.model.alphas_cumprod if use_original_steps else self.ddim_alphas + alphas_prev = ( + self.model.alphas_cumprod_prev + if use_original_steps + else self.ddim_alphas_prev + ) + sqrt_one_minus_alphas = ( + self.model.sqrt_one_minus_alphas_cumprod + if use_original_steps + else self.ddim_sqrt_one_minus_alphas + ) + sigmas = ( + 
self.model.ddim_sigmas_for_original_num_steps + if use_original_steps + else self.ddim_sigmas + ) + # select parameters corresponding to the currently considered timestep + a_t = torch.full((b, 1, 1, 1), alphas[index], device=device) + a_prev = torch.full((b, 1, 1, 1), alphas_prev[index], device=device) + sigma_t = torch.full((b, 1, 1, 1), sigmas[index], device=device) + sqrt_one_minus_at = torch.full( + (b, 1, 1, 1), sqrt_one_minus_alphas[index], device=device + ) + + # current prediction for x_0 + pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt() + if quantize_denoised: + pred_x0, _, *_ = self.model.first_stage_model.quantize(pred_x0) + # direction pointing to x_t + dir_xt = (1.0 - a_prev - sigma_t**2).sqrt() * e_t + noise = sigma_t * noise_like(x.shape, device, repeat_noise) * temperature + if noise_dropout > 0.0: + noise = torch.nn.functional.dropout(noise, p=noise_dropout) + x_prev = a_prev.sqrt() * pred_x0 + dir_xt + noise # TODO + return x_prev, pred_x0 diff --git a/audioldm/latent_diffusion/ddpm.py b/audioldm/latent_diffusion/ddpm.py new file mode 100644 index 0000000000000000000000000000000000000000..ffca031c27d413698adee5a58547b7d0ea4069c3 --- /dev/null +++ b/audioldm/latent_diffusion/ddpm.py @@ -0,0 +1,441 @@ +""" +wild mixture of +https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +https://github.com/openai/improved-diffusion/blob/e94489283bb876ac1477d5dd7709bbbd2d9902ce/improved_diffusion/gaussian_diffusion.py +https://github.com/CompVis/taming-transformers +-- merci +""" +import sys +import os + +import torch +import torch.nn as nn +import numpy as np +from contextlib import contextmanager +from functools import partial +from tqdm import tqdm + +from audioldm.utils import exists, default, count_params, instantiate_from_config +from audioldm.latent_diffusion.ema import LitEma +from audioldm.latent_diffusion.util import ( + 
make_beta_schedule, + extract_into_tensor, + noise_like, +) +import soundfile as sf +import os + + +__conditioning_keys__ = {"concat": "c_concat", "crossattn": "c_crossattn", "adm": "y"} + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +def uniform_on_device(r1, r2, shape, device): + return (r1 - r2) * torch.rand(*shape, device=device) + r2 + + +class DiffusionWrapper(nn.Module): + def __init__(self, diff_model_config, conditioning_key): + super().__init__() + self.diffusion_model = instantiate_from_config(diff_model_config) + self.conditioning_key = conditioning_key + assert self.conditioning_key in [ + None, + "concat", + "crossattn", + "hybrid", + "adm", + "film", + ] + + def forward( + self, x, t, c_concat: list = None, c_crossattn: list = None, c_film: list = None + ): + x = x.contiguous() + t = t.contiguous() + + if self.conditioning_key is None: + out = self.diffusion_model(x, t) + elif self.conditioning_key == "concat": + xc = torch.cat([x] + c_concat, dim=1) + out = self.diffusion_model(xc, t) + elif self.conditioning_key == "crossattn": + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(x, t, context=cc) + elif self.conditioning_key == "hybrid": + xc = torch.cat([x] + c_concat, dim=1) + cc = torch.cat(c_crossattn, 1) + out = self.diffusion_model(xc, t, context=cc) + elif ( + self.conditioning_key == "film" + ): # The condition is assumed to be a global token, which wil pass through a linear layer and added with the time embedding for the FILM + cc = c_film[0].squeeze(1) # only has one token + out = self.diffusion_model(x, t, y=cc) + elif self.conditioning_key == "adm": + cc = c_crossattn[0] + out = self.diffusion_model(x, t, y=cc) + else: + raise NotImplementedError() + + return out + + +class DDPM(nn.Module): + # classic DDPM with Gaussian diffusion, in image space + def __init__( + self, + unet_config, + timesteps=1000, + 
beta_schedule="linear", + loss_type="l2", + ckpt_path=None, + ignore_keys=[], + load_only_unet=False, + monitor="val/loss", + use_ema=True, + first_stage_key="image", + latent_t_size=256, + latent_f_size=16, + channels=3, + log_every_t=100, + clip_denoised=True, + linear_start=1e-4, + linear_end=2e-2, + cosine_s=8e-3, + given_betas=None, + original_elbo_weight=0.0, + v_posterior=0.0, # weight for choosing posterior variance as sigma = (1-v) * beta_tilde + v * beta + l_simple_weight=1.0, + conditioning_key=None, + parameterization="eps", # all assuming fixed variance schedules + scheduler_config=None, + use_positional_encodings=False, + learn_logvar=False, + logvar_init=0.0, + ): + super().__init__() + assert parameterization in [ + "eps", + "x0", + ], 'currently only supporting "eps" and "x0"' + self.parameterization = parameterization + self.state = None + # print(f"{self.__class__.__name__}: Running in {self.parameterization}-prediction mode") + self.cond_stage_model = None + self.clip_denoised = clip_denoised + self.log_every_t = log_every_t + self.first_stage_key = first_stage_key + + self.latent_t_size = latent_t_size + self.latent_f_size = latent_f_size + + self.channels = channels + self.use_positional_encodings = use_positional_encodings + self.model = DiffusionWrapper(unet_config, conditioning_key) + count_params(self.model, verbose=True) + self.use_ema = use_ema + if self.use_ema: + self.model_ema = LitEma(self.model) + # print(f"Keeping EMAs of {len(list(self.model_ema.buffers()))}.") + + self.use_scheduler = scheduler_config is not None + if self.use_scheduler: + self.scheduler_config = scheduler_config + + self.v_posterior = v_posterior + self.original_elbo_weight = original_elbo_weight + self.l_simple_weight = l_simple_weight + + if monitor is not None: + self.monitor = monitor + + self.register_schedule( + given_betas=given_betas, + beta_schedule=beta_schedule, + timesteps=timesteps, + linear_start=linear_start, + linear_end=linear_end, + 
def register_schedule(
    self,
    given_betas=None,
    beta_schedule="linear",
    timesteps=1000,
    linear_start=1e-4,
    linear_end=2e-2,
    cosine_s=8e-3,
):
    """Precompute the diffusion schedule and register it as module buffers.

    The betas are taken from ``given_betas`` when ``exists`` accepts it,
    otherwise they are produced by ``make_beta_schedule``. Every derived
    quantity is converted to a float32 tensor and stored with
    ``register_buffer``.

    Reads ``self.v_posterior`` and ``self.parameterization``; sets
    ``self.num_timesteps``, ``self.linear_start`` and ``self.linear_end``.

    :param given_betas: optional precomputed betas; assumed to be a 1-D
        NumPy array (its ``shape`` is unpacked into a single length).
    :param beta_schedule: schedule name forwarded to ``make_beta_schedule``.
    :param timesteps: requested number of steps; overwritten by the actual
        length of the betas.
    :param linear_start: forwarded to ``make_beta_schedule`` and stored.
    :param linear_end: forwarded to ``make_beta_schedule`` and stored.
    :param cosine_s: forwarded to ``make_beta_schedule``.
    :raises NotImplementedError: if ``self.parameterization`` is neither
        ``"eps"`` nor ``"x0"``.
    :raises AssertionError: if the cumulative alphas do not cover every
        timestep, or if any loss weight is NaN.
    """
    if exists(given_betas):
        betas = given_betas
    else:
        betas = make_beta_schedule(
            beta_schedule,
            timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
            cosine_s=cosine_s,
        )
    alphas = 1.0 - betas
    alphas_cumprod = np.cumprod(alphas, axis=0)
    # Cumulative product shifted by one step, with 1.0 for the first step.
    alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

    (timesteps,) = betas.shape
    self.num_timesteps = int(timesteps)
    self.linear_start = linear_start
    self.linear_end = linear_end
    assert (
        alphas_cumprod.shape[0] == self.num_timesteps
    ), "alphas have to be defined for each timestep"

    to_torch = partial(torch.tensor, dtype=torch.float32)

    self.register_buffer("betas", to_torch(betas))
    self.register_buffer("alphas_cumprod", to_torch(alphas_cumprod))
    self.register_buffer("alphas_cumprod_prev", to_torch(alphas_cumprod_prev))

    # calculations for diffusion q(x_t | x_{t-1}) and others
    self.register_buffer("sqrt_alphas_cumprod", to_torch(np.sqrt(alphas_cumprod)))
    self.register_buffer(
        "sqrt_one_minus_alphas_cumprod", to_torch(np.sqrt(1.0 - alphas_cumprod))
    )
    self.register_buffer(
        "log_one_minus_alphas_cumprod", to_torch(np.log(1.0 - alphas_cumprod))
    )
    self.register_buffer(
        "sqrt_recip_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod))
    )
    self.register_buffer(
        "sqrt_recipm1_alphas_cumprod", to_torch(np.sqrt(1.0 / alphas_cumprod - 1))
    )

    # calculations for posterior q(x_{t-1} | x_t, x_0)
    posterior_variance = (1 - self.v_posterior) * betas * (
        1.0 - alphas_cumprod_prev
    ) / (1.0 - alphas_cumprod) + self.v_posterior * betas
    # above: equal to 1. / (1. / (1. - alpha_cumprod_tm1) + alpha_t / beta_t)
    self.register_buffer("posterior_variance", to_torch(posterior_variance))
    # below: log calculation clipped because the posterior variance is 0
    # at the beginning of the diffusion chain
    self.register_buffer(
        "posterior_log_variance_clipped",
        to_torch(np.log(np.maximum(posterior_variance, 1e-20))),
    )
    self.register_buffer(
        "posterior_mean_coef1",
        to_torch(betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod)),
    )
    self.register_buffer(
        "posterior_mean_coef2",
        to_torch(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod)
        ),
    )

    if self.parameterization == "eps":
        lvlb_weights = self.betas**2 / (
            2
            * self.posterior_variance
            * to_torch(alphas)
            * (1 - self.alphas_cumprod)
        )
    elif self.parameterization == "x0":
        # torch.sqrt keeps the whole expression in torch instead of routing
        # a tensor through NumPy's ufunc machinery.
        alphas_cumprod_t = torch.Tensor(alphas_cumprod)
        lvlb_weights = 0.5 * torch.sqrt(alphas_cumprod_t) / (
            2.0 * 1 - alphas_cumprod_t
        )
    else:
        raise NotImplementedError("mu not supported")
    # TODO how to choose this term
    # The first weight is overwritten with the second one (for "eps" the
    # first posterior variance can be zero, making the raw weight infinite).
    lvlb_weights[0] = lvlb_weights[1]
    # persistent=False keeps the weights out of the state dict.
    self.register_buffer("lvlb_weights", lvlb_weights, persistent=False)
    # Reject a schedule containing *any* NaN weight; `.all()` would only
    # trigger when every single weight is NaN.
    assert not torch.isnan(self.lvlb_weights).any()
def p_mean_variance(self, x, t, clip_denoised: bool):
    """Compute the reverse-step posterior from the model's prediction.

    The model output is turned into an estimate of the clean sample
    according to ``self.parameterization`` and passed to
    ``self.q_posterior`` together with the current sample.

    :param x: the current noisy sample, passed to the model as its input.
    :param t: the timestep argument handed to the model and to
        ``q_posterior``.
    :param clip_denoised: if True, clamp the clean-sample estimate in place
        to ``[-1, 1]``.
    :return: the tuple returned by ``q_posterior`` (mean, variance,
        log-variance).
    :raises NotImplementedError: if ``self.parameterization`` is neither
        ``"eps"`` nor ``"x0"``.
    """
    model_out = self.model(x, t)
    if self.parameterization == "eps":
        # The model predicts noise; recover the clean-sample estimate.
        x_recon = self.predict_start_from_noise(x, t=t, noise=model_out)
    elif self.parameterization == "x0":
        # The model predicts the clean sample directly.
        x_recon = model_out
    else:
        # Without this branch x_recon would be unbound and the failure
        # would surface as an opaque UnboundLocalError further down.
        raise NotImplementedError(
            f"Parameterization {self.parameterization} is not supported"
        )
    if clip_denoised:
        x_recon.clamp_(-1.0, 1.0)

    model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
        x_start=x_recon, x_t=x, t=t
    )
    return model_mean, posterior_variance, posterior_log_variance
@torch.no_grad()
def sample(self, batch_size=16, return_intermediates=False):
    """Draw a batch of samples by running the full reverse loop.

    :param batch_size: number of samples to generate.
    :param return_intermediates: forwarded to ``p_sample_loop``.
    :return: whatever ``p_sample_loop`` returns for the shape
        ``(batch_size, self.channels, self.latent_t_size, self.latent_f_size)``.
    """
    # The channel count must be read before the shape tuple is built;
    # referencing the local first raised UnboundLocalError on every call.
    channels = self.channels
    shape = (batch_size, channels, self.latent_t_size, self.latent_f_size)
    return self.p_sample_loop(shape, return_intermediates=return_intermediates)
class LitEma(nn.Module):
    """Exponential moving average of a model's trainable parameters.

    A shadow copy of every parameter with ``requires_grad`` set is kept as a
    registered buffer. Calling the module with the model updates the shadow
    copies; ``store`` / ``copy_to`` / ``restore`` temporarily swap the
    averaged weights into the model.

    :param model: module whose trainable parameters are tracked.
    :param decay: averaging decay, must lie in ``[0, 1]``.
    :param use_num_upates: (spelling kept for compatibility) if True the
        update counter starts at 0 and the effective decay is warmed up as
        ``min(decay, (1 + n) / (10 + n))``; if False the counter is -1 and
        ``decay`` is used unchanged.
    :raises ValueError: if ``decay`` is outside ``[0, 1]``.
    """

    def __init__(self, model, decay=0.9999, use_num_upates=True):
        super().__init__()
        if decay < 0.0 or decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        # Maps a model parameter name to the name of its shadow buffer.
        self.m_name2s_name = {}
        self.register_buffer("decay", torch.tensor(decay, dtype=torch.float32))
        self.register_buffer(
            "num_updates",
            torch.tensor(0, dtype=torch.int)
            if use_num_upates
            else torch.tensor(-1, dtype=torch.int),
        )

        for name, p in model.named_parameters():
            if p.requires_grad:
                # remove as '.'-character is not allowed in buffers
                s_name = name.replace(".", "")
                self.m_name2s_name[name] = s_name
                self.register_buffer(s_name, p.clone().detach().data)

        self.collected_params = []

    def forward(self, model):
        """Move every shadow buffer one EMA step towards ``model``'s weights."""
        decay = self.decay

        if self.num_updates >= 0:
            self.num_updates += 1
            decay = min(self.decay, (1 + self.num_updates) / (10 + self.num_updates))

        one_minus_decay = 1.0 - decay

        with torch.no_grad():
            m_param = dict(model.named_parameters())
            shadow_params = dict(self.named_buffers())

            for key in m_param:
                if m_param[key].requires_grad:
                    shadow = shadow_params[self.m_name2s_name[key]]
                    # Cast the parameter to the shadow's type, not the other
                    # way round: casting the shadow yields a copy whenever
                    # the types differ, so the in-place update would never
                    # reach the registered buffer.
                    current = m_param[key].type_as(shadow)
                    shadow.sub_(one_minus_decay * (shadow - current))
                else:
                    assert key not in self.m_name2s_name

    def copy_to(self, model):
        """Overwrite ``model``'s trainable parameters with the shadow values."""
        m_param = dict(model.named_parameters())
        shadow_params = dict(self.named_buffers())
        for key in m_param:
            if m_param[key].requires_grad:
                m_param[key].data.copy_(shadow_params[self.m_name2s_name[key]].data)
            else:
                assert key not in self.m_name2s_name

    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
          parameters: Iterable of `torch.nn.Parameter`; the parameters to be
            updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """
    Sequential container that routes the extra inputs to the children that
    accept them: timestep embeddings go to `TimestepBlock` children, the
    context goes to `SpatialTransformer` children, and every other child
    receives only the running activation.
    """

    def forward(self, x, emb, context=None):
        h = x
        for module in self:
            # Pick the single extra positional argument (if any) by type.
            if isinstance(module, TimestepBlock):
                extra = (emb,)
            elif isinstance(module, SpatialTransformer):
                extra = (context,)
            else:
                extra = ()
            h = module(h, *extra)
        return h
class TransposedUpsample(nn.Module):
    """
    Learned upsampling through a stride-2 transposed 2D convolution that is
    built without explicit padding.
    :param channels: channels in the inputs.
    :param out_channels: channels in the outputs; falls back to `channels`
                         when not given (or falsy).
    :param ks: kernel size of the transposed convolution.
    """

    def __init__(self, channels, out_channels=None, ks=5):
        super().__init__()
        target_channels = out_channels if out_channels else channels
        self.channels = channels
        self.out_channels = target_channels

        self.up = nn.ConvTranspose2d(
            in_channels=channels,
            out_channels=target_channels,
            kernel_size=ks,
            stride=2,
        )

    def forward(self, x):
        upsampled = self.up(x)
        return upsampled
class ResBlock(TimestepBlock):
    """
    A residual block that can optionally change the number of channels.
    :param channels: the number of input channels.
    :param emb_channels: the number of timestep embedding channels.
    :param dropout: the rate of dropout.
    :param out_channels: if specified, the number of out channels.
    :param use_conv: if True and out_channels is specified, use a spatial
        convolution instead of a smaller 1x1 convolution to change the
        channels in the skip connection.
    :param use_scale_shift_norm: if True, the embedding is projected to
        twice the output channels and split into a scale and a shift that
        modulate the normalized activations; otherwise the projected
        embedding is simply added.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param use_checkpoint: if True, use gradient checkpointing on this module.
    :param up: if True, use this block for upsampling.
    :param down: if True, use this block for downsampling.
    """

    def __init__(
        self,
        channels,
        emb_channels,
        dropout,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        dims=2,
        use_checkpoint=False,
        up=False,
        down=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_checkpoint = use_checkpoint
        self.use_scale_shift_norm = use_scale_shift_norm

        # Input path: norm -> SiLU -> 3-wide conv that maps to out_channels.
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            conv_nd(dims, channels, self.out_channels, 3, padding=1),
        )

        self.updown = up or down

        # Separate resamplers for the hidden path (h) and the skip path (x);
        # both are built without a convolution (use_conv=False).
        if up:
            self.h_upd = Upsample(channels, False, dims)
            self.x_upd = Upsample(channels, False, dims)
        elif down:
            self.h_upd = Downsample(channels, False, dims)
            self.x_upd = Downsample(channels, False, dims)
        else:
            self.h_upd = self.x_upd = nn.Identity()

        # Projects the embedding to out_channels, or to 2 * out_channels
        # when it has to provide both a scale and a shift.
        self.emb_layers = nn.Sequential(
            nn.SiLU(),
            linear(
                emb_channels,
                2 * self.out_channels if use_scale_shift_norm else self.out_channels,
            ),
        )
        # Output path; the final conv is wrapped in zero_module
        # (presumably zero-initialised -- confirm in util).
        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
            ),
        )

        # The skip connection only needs a projection when channels change.
        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = conv_nd(
                dims, channels, self.out_channels, 3, padding=1
            )
        else:
            self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)

    def forward(self, x, emb):
        """
        Apply the block to a Tensor, conditioned on a timestep embedding.
        :param x: an [N x C x ...] Tensor of features.
        :param emb: an [N x emb_channels] Tensor of timestep embeddings.
        :return: an [N x C x ...] Tensor of outputs.
        """
        # Delegates to _forward through the checkpoint helper, which is
        # switched by self.use_checkpoint.
        return checkpoint(
            self._forward, (x, emb), self.parameters(), self.use_checkpoint
        )

    def _forward(self, x, emb):
        """
        Actual computation of the block; see `forward` for the arguments.
        """
        if self.updown:
            # Resample between the norm/activation and the input conv, and
            # resample the skip input the same way.
            in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
            h = in_rest(x)
            h = self.h_upd(h)
            x = self.x_upd(x)
            h = in_conv(h)
        else:
            h = self.in_layers(x)
        emb_out = self.emb_layers(emb).type(h.dtype)
        # Append trailing singleton axes until emb_out has as many
        # dimensions as h, so it broadcasts over the remaining axes.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            # Modulate right after the normalization layer, then run the
            # remaining output layers.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = th.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
+ """ + + def __init__( + self, + channels, + num_heads=1, + num_head_channels=-1, + use_checkpoint=False, + use_new_attention_order=False, + ): + super().__init__() + self.channels = channels + if num_head_channels == -1: + self.num_heads = num_heads + else: + assert ( + channels % num_head_channels == 0 + ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}" + self.num_heads = channels // num_head_channels + self.use_checkpoint = use_checkpoint + self.norm = normalization(channels) + self.qkv = conv_nd(1, channels, channels * 3, 1) + if use_new_attention_order: + # split qkv before split heads + self.attention = QKVAttention(self.num_heads) + else: + # split heads before split qkv + self.attention = QKVAttentionLegacy(self.num_heads) + + self.proj_out = zero_module(conv_nd(1, channels, channels, 1)) + + def forward(self, x): + return checkpoint( + self._forward, (x,), self.parameters(), True + ) # TODO: check checkpoint usage, is True # TODO: fix the .half call!!! + # return pt_checkpoint(self._forward, x) # pytorch + + def _forward(self, x): + b, c, *spatial = x.shape + x = x.reshape(b, c, -1).contiguous() + qkv = self.qkv(self.norm(x)).contiguous() + h = self.attention(qkv).contiguous() + h = self.proj_out(h).contiguous() + return (x + h).reshape(b, c, *spatial).contiguous() + + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. 
class QKVAttentionLegacy(nn.Module):
    """
    QKV attention that splits the channel axis into heads first and only
    then into queries, keys and values (the legacy head layout).
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv):
        """
        Run attention over the last axis.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        batch, channels, seq_len = qkv.shape
        group = 3 * self.n_heads
        assert channels % group == 0
        head_dim = channels // group

        # Fold the heads into the batch axis, then cut each head's block
        # into its query / key / value thirds.
        per_head = qkv.reshape(
            batch * self.n_heads, head_dim * 3, seq_len
        ).contiguous()
        query, key, value = per_head.split(head_dim, dim=1)

        # Scaling both operands is more stable with f16 than dividing the
        # product afterwards.
        scale = 1 / math.sqrt(math.sqrt(head_dim))
        logits = th.einsum("bct,bcs->bts", query * scale, key * scale)
        probs = th.softmax(logits.float(), dim=-1).type(logits.dtype)
        attended = th.einsum("bts,bcs->bct", probs, value)
        return attended.reshape(batch, -1, seq_len).contiguous()

    @staticmethod
    def count_flops(model, _x, y):
        return count_flops_attn(model, _x, y)
class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.
    :param image_size: stored on the instance; not otherwise used here.
    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param extra_film_condition_dim: if specified, the size of an extra
        conditioning vector that is projected and merged into the timestep
        embedding. Mutually exclusive with `num_classes`.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param use_fp16: if True, inputs are cast to float16 inside `forward`.
    :param num_heads: the number of attention heads in each attention layer.
    :param num_head_channels: if specified, ignore num_heads and instead use
        a fixed channel width per attention head.
    :param num_heads_upsample: works with num_heads to set a different number
        of heads for upsampling. Deprecated.
    :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
    :param extra_film_use_concat: if True, concatenate the extra condition
        with the time embedding, otherwise add it.
    :param resblock_updown: use residual blocks for up/downsampling.
    :param use_new_attention_order: use a different attention pattern for potentially
        increased efficiency.
    :param use_spatial_transformer: use `SpatialTransformer` layers instead
        of `AttentionBlock` layers.
    :param transformer_depth: depth passed to each `SpatialTransformer`.
    :param context_dim: context dimension passed to each `SpatialTransformer`.
    :param n_embed: if specified, `forward` returns the output of an
        id-predictor head with `n_embed` channels instead of `self.out`.
    :param legacy: if True, use the legacy rule for the per-head dimension.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        extra_film_condition_dim=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=-1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        extra_film_use_concat=False,  # If true, concatenate extrafilm condition with time embedding, else addition
        resblock_updown=False,
        use_new_attention_order=False,
        use_spatial_transformer=False,  # custom transformer support
        transformer_depth=1,  # custom transformer support
        context_dim=None,  # custom transformer support
        n_embed=None,  # custom support for prediction of discrete ids into codebook of first stage vq model
        legacy=True,
    ):
        super().__init__()
        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        if num_heads == -1:
            assert (
                num_head_channels != -1
            ), "Either num_heads or num_head_channels has to be set"

        if num_head_channels == -1:
            assert (
                num_heads != -1
            ), "Either num_heads or num_head_channels has to be set"

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.extra_film_condition_dim = extra_film_condition_dim
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.predict_codebook_ids = n_embed is not None
        self.extra_film_use_concat = extra_film_use_concat
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            nn.SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        assert not (
            self.num_classes is not None and self.extra_film_condition_dim is not None
        ), "As for the condition of theh UNet model, you can only set using class label or an extra embedding vector (such as from CLAP). You cannot set both num_classes and extra_film_condition_dim."

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        self.use_extra_film_by_concat = (
            self.extra_film_condition_dim is not None and self.extra_film_use_concat
        )
        self.use_extra_film_by_addition = (
            self.extra_film_condition_dim is not None and not self.extra_film_use_concat
        )

        if self.extra_film_condition_dim is not None:
            self.film_emb = nn.Linear(self.extra_film_condition_dim, time_embed_dim)

        # With a FiLM condition the spatial transformer is built without a
        # cross-attention context.
        spatial_transformer_no_context = bool(
            use_spatial_transformer
            and (self.use_extra_film_by_concat or self.use_extra_film_by_addition)
        )

        if use_spatial_transformer and not spatial_transformer_no_context:
            assert (
                context_dim is not None
            ), "Fool!! You forgot to include the dimension of your cross-attention conditioning..."

        if context_dim is not None and not spatial_transformer_no_context:
            assert (
                use_spatial_transformer
            ), "Fool!! You forgot to use the spatial transformer for your cross-attention conditioning..."
            from omegaconf.listconfig import ListConfig

            if isinstance(context_dim, ListConfig):
                context_dim = list(context_dim)

        # Width of the embedding every ResBlock receives: doubled when the
        # FiLM embedding is concatenated to the time embedding.
        emb_dim = (
            time_embed_dim * 2 if self.use_extra_film_by_concat else time_embed_dim
        )

        def make_res_block(in_ch, out_ch=None, **resample):
            # One ResBlock with the settings shared by the whole network.
            return ResBlock(
                in_ch,
                emb_dim,
                dropout,
                out_channels=out_ch,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
                **resample,
            )

        def head_layout(width, heads):
            # Returns the (possibly updated) head count and per-head width.
            if num_head_channels == -1:
                dim_head = width // heads
            else:
                heads = width // num_head_channels
                dim_head = num_head_channels
            if legacy:
                dim_head = (
                    width // heads if use_spatial_transformer else num_head_channels
                )
            return heads, dim_head

        def make_attention(width, heads, dim_head, block_heads):
            # `block_heads` is only used by the plain AttentionBlock.
            if use_spatial_transformer:
                return SpatialTransformer(
                    width,
                    heads,
                    dim_head,
                    depth=transformer_depth,
                    context_dim=context_dim,
                    no_context=spatial_transformer_no_context,
                )
            return AttentionBlock(
                width,
                use_checkpoint=use_checkpoint,
                num_heads=block_heads,
                num_head_channels=dim_head,
                use_new_attention_order=use_new_attention_order,
            )

        # ---- encoder ----
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        self._feature_size = model_channels
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        last_level = len(channel_mult) - 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [make_res_block(ch, mult * model_channels)]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    num_heads, dim_head = head_layout(ch, num_heads)
                    layers.append(make_attention(ch, num_heads, dim_head, num_heads))
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch
                input_block_chans.append(ch)
            if level != last_level:
                out_ch = ch
                self.input_blocks.append(
                    TimestepEmbedSequential(
                        make_res_block(ch, out_ch, down=True)
                        if resblock_updown
                        else Downsample(
                            ch, conv_resample, dims=dims, out_channels=out_ch
                        )
                    )
                )
                ch = out_ch
                input_block_chans.append(ch)
                ds *= 2
                self._feature_size += ch

        # ---- bottleneck ----
        num_heads, dim_head = head_layout(ch, num_heads)
        self.middle_block = TimestepEmbedSequential(
            make_res_block(ch),
            make_attention(ch, num_heads, dim_head, num_heads),
            make_res_block(ch),
        )
        self._feature_size += ch

        # ---- decoder ----
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                # Channels of the matching encoder activation that gets
                # concatenated in `forward`.
                ich = input_block_chans.pop()
                layers = [make_res_block(ch + ich, model_channels * mult)]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    num_heads, dim_head = head_layout(ch, num_heads)
                    layers.append(
                        make_attention(ch, num_heads, dim_head, num_heads_upsample)
                    )
                if level and i == num_res_blocks:
                    out_ch = ch
                    layers.append(
                        make_res_block(ch, out_ch, up=True)
                        if resblock_updown
                        else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
                    )
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))
                self._feature_size += ch

        # The decoder ends with `ch` channels (model_channels *
        # channel_mult[0]); the head convolutions must consume exactly that
        # many, otherwise any channel_mult[0] != 1 fails at run time.
        self.out = nn.Sequential(
            normalization(ch),
            nn.SiLU(),
            zero_module(conv_nd(dims, ch, out_channels, 3, padding=1)),
        )
        if self.predict_codebook_ids:
            self.id_predictor = nn.Sequential(
                normalization(ch),
                conv_nd(dims, ch, n_embed, 1),
                # nn.LogSoftmax(dim=1)  # change to cross_entropy and produce non-normalized logits
            )

        self.shape_reported = False

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
        """
        Apply the model to an input batch.
        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param context: conditioning plugged in via crossattn
        :param y: an [N] Tensor of labels, if class-conditional. an [N, extra_film_condition_dim] Tensor if film-embed conditional
        :return: an [N x C x ...] Tensor of outputs.
        """
        if not self.shape_reported:
            self.shape_reported = True

        assert (y is not None) == (
            self.num_classes is not None or self.extra_film_condition_dim is not None
        ), "must specify y if and only if the model is class-conditional or film embedding conditional"
        hs = []
        t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
        emb = self.time_embed(t_emb)

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        if self.use_extra_film_by_addition:
            emb = emb + self.film_emb(y)
        elif self.use_extra_film_by_concat:
            emb = th.cat([emb, self.film_emb(y)], dim=-1)

        h = x.type(self.dtype)
        for module in self.input_blocks:
            h = module(h, emb, context)
            # Keep every encoder activation for the skip connections.
            hs.append(h)
        h = self.middle_block(h, emb, context)
        for module in self.output_blocks:
            # Skip connection: concatenate the latest stored activation.
            h = th.cat([h, hs.pop()], dim=1)
            h = module(h, emb, context)
        h = h.type(x.dtype)
        if self.predict_codebook_ids:
            return self.id_predictor(h)
        return self.out(h)
+ """ + + def __init__( + self, + image_size, + in_channels, + model_channels, + out_channels, + num_res_blocks, + attention_resolutions, + dropout=0, + channel_mult=(1, 2, 4, 8), + conv_resample=True, + dims=2, + use_checkpoint=False, + use_fp16=False, + num_heads=1, + num_head_channels=-1, + num_heads_upsample=-1, + use_scale_shift_norm=False, + resblock_updown=False, + use_new_attention_order=False, + pool="adaptive", + *args, + **kwargs, + ): + super().__init__() + + if num_heads_upsample == -1: + num_heads_upsample = num_heads + + self.in_channels = in_channels + self.model_channels = model_channels + self.out_channels = out_channels + self.num_res_blocks = num_res_blocks + self.attention_resolutions = attention_resolutions + self.dropout = dropout + self.channel_mult = channel_mult + self.conv_resample = conv_resample + self.use_checkpoint = use_checkpoint + self.dtype = th.float16 if use_fp16 else th.float32 + self.num_heads = num_heads + self.num_head_channels = num_head_channels + self.num_heads_upsample = num_heads_upsample + + time_embed_dim = model_channels * 4 + self.time_embed = nn.Sequential( + linear(model_channels, time_embed_dim), + nn.SiLU(), + linear(time_embed_dim, time_embed_dim), + ) + + self.input_blocks = nn.ModuleList( + [ + TimestepEmbedSequential( + conv_nd(dims, in_channels, model_channels, 3, padding=1) + ) + ] + ) + self._feature_size = model_channels + input_block_chans = [model_channels] + ch = model_channels + ds = 1 + for level, mult in enumerate(channel_mult): + for _ in range(num_res_blocks): + layers = [ + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=mult * model_channels, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ) + ] + ch = mult * model_channels + if ds in attention_resolutions: + layers.append( + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + 
use_new_attention_order=use_new_attention_order, + ) + ) + self.input_blocks.append(TimestepEmbedSequential(*layers)) + self._feature_size += ch + input_block_chans.append(ch) + if level != len(channel_mult) - 1: + out_ch = ch + self.input_blocks.append( + TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + out_channels=out_ch, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + down=True, + ) + if resblock_updown + else Downsample( + ch, conv_resample, dims=dims, out_channels=out_ch + ) + ) + ) + ch = out_ch + input_block_chans.append(ch) + ds *= 2 + self._feature_size += ch + + self.middle_block = TimestepEmbedSequential( + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + AttentionBlock( + ch, + use_checkpoint=use_checkpoint, + num_heads=num_heads, + num_head_channels=num_head_channels, + use_new_attention_order=use_new_attention_order, + ), + ResBlock( + ch, + time_embed_dim, + dropout, + dims=dims, + use_checkpoint=use_checkpoint, + use_scale_shift_norm=use_scale_shift_norm, + ), + ) + self._feature_size += ch + self.pool = pool + if pool == "adaptive": + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + nn.AdaptiveAvgPool2d((1, 1)), + zero_module(conv_nd(dims, ch, out_channels, 1)), + nn.Flatten(), + ) + elif pool == "attention": + assert num_head_channels != -1 + self.out = nn.Sequential( + normalization(ch), + nn.SiLU(), + AttentionPool2d( + (image_size // ds), ch, num_head_channels, out_channels + ), + ) + elif pool == "spatial": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + nn.ReLU(), + nn.Linear(2048, self.out_channels), + ) + elif pool == "spatial_v2": + self.out = nn.Sequential( + nn.Linear(self._feature_size, 2048), + normalization(2048), + nn.SiLU(), + nn.Linear(2048, self.out_channels), + ) + else: + raise NotImplementedError(f"Unexpected {pool} pooling") + + 
def convert_to_fp16(self):
    """
    Cast the torso of the model (input blocks and middle block) to float16.
    """
    for torso_part in (self.input_blocks, self.middle_block):
        torso_part.apply(convert_module_to_f16)


def convert_to_fp32(self):
    """
    Cast the torso of the model (input blocks and middle block) to float32.
    """
    for torso_part in (self.input_blocks, self.middle_block):
        torso_part.apply(convert_module_to_f32)


def forward(self, x, timesteps):
    """
    Run the encoder on an input batch.

    For pooling modes whose name starts with "spatial", the output of every
    input block and of the middle block is averaged over dims (2, 3), the
    results are concatenated along the last axis and passed to ``self.out``.
    For the other modes only the middle-block output is passed to ``self.out``.

    :param x: an [N x C x ...] Tensor of inputs.
    :param timesteps: a 1-D batch of timesteps.
    :return: an [N x K] Tensor of outputs.
    """
    emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

    pooled = []
    h = x.type(self.dtype)
    for block in self.input_blocks:
        h = block(h, emb)
        if self.pool.startswith("spatial"):
            pooled.append(h.type(x.dtype).mean(dim=(2, 3)))

    h = self.middle_block(h, emb)
    if not self.pool.startswith("spatial"):
        return self.out(h.type(x.dtype))

    pooled.append(h.type(x.dtype).mean(dim=(2, 3)))
    return self.out(th.cat(pooled, axis=-1))
def make_beta_schedule(
    schedule, n_timestep, linear_start=1e-4, linear_end=2e-2, cosine_s=8e-3
):
    """
    Build the beta schedule of the diffusion process.

    :param schedule: one of
        "linear"      - squares of a linspace from sqrt(linear_start) to
                        sqrt(linear_end);
        "cosine"      - betas derived from a squared-cosine alpha-bar curve,
                        clamped to [0, 0.999];
        "sqrt_linear" - a plain linspace from linear_start to linear_end;
        "sqrt"        - the square root of that linspace.
    :param n_timestep: number of betas to produce.
    :param linear_start: first value for the linspace-based schedules.
    :param linear_end: last value for the linspace-based schedules.
    :param cosine_s: offset used by the "cosine" schedule.
    :return: a float64 NumPy array with ``n_timestep`` entries.
    :raises ValueError: if ``schedule`` is not one of the names above.
    """
    if schedule == "linear":
        betas = (
            torch.linspace(
                linear_start**0.5, linear_end**0.5, n_timestep, dtype=torch.float64
            )
            ** 2
        )

    elif schedule == "cosine":
        timesteps = (
            torch.arange(n_timestep + 1, dtype=torch.float64) / n_timestep + cosine_s
        )
        alphas = timesteps / (1 + cosine_s) * np.pi / 2
        alphas = torch.cos(alphas).pow(2)
        alphas = alphas / alphas[0]
        betas = 1 - alphas[1:] / alphas[:-1]
        # Clamp with torch: `betas` is a torch tensor here, and every branch
        # must still be a tensor when `.numpy()` is called below.
        betas = torch.clamp(betas, min=0, max=0.999)

    elif schedule == "sqrt_linear":
        betas = torch.linspace(
            linear_start, linear_end, n_timestep, dtype=torch.float64
        )
    elif schedule == "sqrt":
        betas = (
            torch.linspace(linear_start, linear_end, n_timestep, dtype=torch.float64)
            ** 0.5
        )
    else:
        raise ValueError(f"schedule '{schedule}' unknown.")
    return betas.numpy()


def make_ddim_timesteps(
    ddim_discr_method, num_ddim_timesteps, num_ddpm_timesteps, verbose=True
):
    """
    Select the DDPM timesteps visited by a DDIM sampler.

    :param ddim_discr_method: "uniform" (every ``num_ddpm_timesteps //
                              num_ddim_timesteps``-th step) or "quad"
                              (quadratically spaced steps).
    :param num_ddim_timesteps: requested number of DDIM steps.
    :param num_ddpm_timesteps: number of steps of the underlying DDPM.
    :param verbose: print the selected timesteps.
    :return: integer NumPy array of the selected timesteps, each shifted by +1.
             With "uniform" spacing the array can hold more than
             ``num_ddim_timesteps`` entries when the division is not exact.
    :raises ValueError: for "uniform" spacing when ``num_ddim_timesteps`` is
                        not in ``1..num_ddpm_timesteps``.
    :raises NotImplementedError: for an unknown ``ddim_discr_method``.
    """
    if ddim_discr_method == "uniform":
        # A stride of zero (or a division by zero) would otherwise surface as
        # an opaque range()/ZeroDivisionError failure.
        if not 0 < num_ddim_timesteps <= num_ddpm_timesteps:
            raise ValueError(
                f"num_ddim_timesteps must be in [1, {num_ddpm_timesteps}], "
                f"got {num_ddim_timesteps}"
            )
        c = num_ddpm_timesteps // num_ddim_timesteps
        ddim_timesteps = np.arange(0, num_ddpm_timesteps, c)
    elif ddim_discr_method == "quad":
        ddim_timesteps = (
            (np.linspace(0, np.sqrt(num_ddpm_timesteps * 0.8), num_ddim_timesteps)) ** 2
        ).astype(int)
    else:
        raise NotImplementedError(
            f'There is no ddim discretization method called "{ddim_discr_method}"'
        )

    # add one to get the final alpha values right (the ones from first scale to data during sampling)
    steps_out = ddim_timesteps + 1
    if verbose:
        print(f"Selected timesteps for ddim sampler: {steps_out}")
    return steps_out
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
    """
    Discretize an alpha_t_bar function into a beta schedule.

    ``alpha_bar`` defines the cumulative product of (1-beta) over time
    t in [0, 1]; beta_i is ``1 - alpha_bar(t_{i+1}) / alpha_bar(t_i)`` with
    ``t_i = i / num_diffusion_timesteps``, capped at ``max_beta``.

    :param num_diffusion_timesteps: the number of betas to produce.
    :param alpha_bar: a callable taking t from 0 to 1 and returning the
                      cumulative product of (1-beta) up to that point.
    :param max_beta: the maximum beta to use; use values lower than 1 to
                     prevent singularities.
    :return: a NumPy array holding the betas.
    """
    steps = num_diffusion_timesteps
    return np.array(
        [
            min(1 - alpha_bar((i + 1) / steps) / alpha_bar(i / steps), max_beta)
            for i in range(steps)
        ]
    )


def extract_into_tensor(a, t, x_shape):
    """
    Gather ``a`` at the indices ``t`` (last dim) and shape the result for broadcasting.

    :param a: Tensor to gather from.
    :param t: index Tensor; its first dimension is taken as the batch size.
    :param x_shape: shape of the Tensor the result will be combined with; only
                    its length is used.
    :return: the gathered values reshaped to ``(batch, 1, ..., 1)`` with
             ``len(x_shape) - 1`` trailing singleton dimensions.
    """
    batch, *_ = t.shape
    gathered = a.gather(-1, t).contiguous()
    singleton_dims = (1,) * (len(x_shape) - 1)
    return gathered.reshape(batch, *singleton_dims).contiguous()
class CheckpointFunction(torch.autograd.Function):
    """
    Autograd function that runs ``run_function`` under ``no_grad`` in the
    forward pass and re-runs it with gradients enabled in the backward pass.
    """

    @staticmethod
    def forward(ctx, run_function, length, *args):
        """
        :param ctx: autograd context; the function, inputs and params are stored on it.
        :param run_function: callable invoked with the first ``length`` entries of ``args``.
        :param length: number of leading ``args`` that are inputs to ``run_function``.
        :param args: the inputs followed by the extra parameters.
        :return: whatever ``run_function`` returns, computed without recording a graph.
        """
        ctx.run_function = run_function
        # Split the flat argument list back into inputs and extra parameters.
        ctx.input_tensors = list(args[:length])
        ctx.input_params = list(args[length:])

        with torch.no_grad():
            output_tensors = ctx.run_function(*ctx.input_tensors)
        return output_tensors

    @staticmethod
    def backward(ctx, *output_grads):
        """
        Recompute the forward pass with gradients enabled and return the
        gradients w.r.t. the stored inputs and parameters.

        :return: ``(None, None)`` for ``run_function`` and ``length``, followed
                 by one gradient (or None, since ``allow_unused=True``) per
                 input and parameter.
        """
        # Fresh leaf tensors so the recomputed graph starts at the inputs.
        ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
        with torch.enable_grad():
            # Fixes a bug where the first op in run_function modifies the
            # Tensor storage in place, which is not allowed for detach()'d
            # Tensors.
            shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
            output_tensors = ctx.run_function(*shallow_copies)
        input_grads = torch.autograd.grad(
            output_tensors,
            ctx.input_tensors + ctx.input_params,
            output_grads,
            allow_unused=True,
        )
        # Drop the references held on ctx and the recomputed outputs.
        del ctx.input_tensors
        del ctx.input_params
        del output_tensors
        return (None, None) + input_grads
+ """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) + * torch.arange(start=0, end=half, dtype=torch.float32) + / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat( + [embedding, torch.zeros_like(embedding[:, :1])], dim=-1 + ) + else: + embedding = repeat(timesteps, "b -> b d", d=dim) + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. 
+ """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +class HybridConditioner(nn.Module): + def __init__(self, c_concat_config, c_crossattn_config): + super().__init__() + self.concat_conditioner = instantiate_from_config(c_concat_config) + self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + + def forward(self, c_concat, c_crossattn): + c_concat = self.concat_conditioner(c_concat) + c_crossattn = self.crossattn_conditioner(c_crossattn) + return {"c_concat": [c_concat], "c_crossattn": [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat( + shape[0], *((1,) * (len(shape) - 1)) + ) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() diff --git a/audioldm/ldm.py b/audioldm/ldm.py new file mode 100644 index 0000000000000000000000000000000000000000..e0179fd5a506052ac9db22bd37f3db6b910aded5 --- /dev/null +++ b/audioldm/ldm.py @@ -0,0 +1,818 @@ +import os + +import torch +import numpy as np +from tqdm import tqdm +from audioldm.utils import default, instantiate_from_config, save_wave +from audioldm.latent_diffusion.ddpm import DDPM +from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution +from audioldm.latent_diffusion.util import noise_like +from audioldm.latent_diffusion.ddim import DDIMSampler +import os + + +def disabled_train(self, mode=True): + """Overwrite model.train with this function to make sure train/eval mode + does not change anymore.""" + return self + + +class LatentDiffusion(DDPM): + """main class""" + + def __init__( + self, + device="cuda", + first_stage_config=None, + cond_stage_config=None, + num_timesteps_cond=None, + cond_stage_key="image", + cond_stage_trainable=False, + 
def make_cond_schedule(
    self,
):
    """
    Build ``self.cond_ids``, a long Tensor with ``self.num_timesteps`` entries.

    Every entry starts as ``num_timesteps - 1``; the first
    ``self.num_timesteps_cond`` entries are then overwritten with the rounded
    values of a linspace from 0 to ``num_timesteps - 1``.
    """
    total_steps = self.num_timesteps
    cond_steps = self.num_timesteps_cond
    last_step = total_steps - 1

    self.cond_ids = torch.full(
        size=(total_steps,), fill_value=last_step, dtype=torch.long
    )
    spaced = torch.linspace(0, last_step, cond_steps)
    self.cond_ids[:cond_steps] = torch.round(spaced).long()
def get_first_stage_encoding(self, encoder_posterior):
    """
    Turn the first-stage encoder output into a scaled latent.

    :param encoder_posterior: a ``DiagonalGaussianDistribution`` (a sample is
                              drawn from it) or a ``torch.Tensor`` (used as is).
    :return: the latent multiplied by ``self.scale_factor``.
    :raises NotImplementedError: for any other input type.
    """
    if isinstance(encoder_posterior, DiagonalGaussianDistribution):
        latent = encoder_posterior.sample()
        return self.scale_factor * latent
    if isinstance(encoder_posterior, torch.Tensor):
        return self.scale_factor * encoder_posterior
    raise NotImplementedError(
        f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented"
    )
@torch.no_grad()
def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False):
    """
    Decode a latent with the first-stage model.

    :param z: the latent to decode. With ``predict_cids`` it is either a 4-D
              Tensor of per-code scores over dim 1 (reduced with argmax) or
              already a Tensor of codebook indices.
    :param predict_cids: if True, look the indices up in the first-stage
                         codebook before decoding.
    :param force_not_quantize: accepted for interface compatibility; not used
                               by this method.
    :return: ``self.first_stage_model.decode`` applied to ``z / self.scale_factor``.
    """
    if predict_cids:
        if z.dim() == 4:
            z = torch.argmax(z.exp(), dim=1).long()
        z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None)
        # "b h w c -> b c h w". Done with permute because `rearrange` is not
        # among this module's imports, so the previous call raised NameError.
        z = z.permute(0, 3, 1, 2).contiguous()

    # Undo the scaling applied when the latent was produced.
    z = 1.0 / self.scale_factor * z
    return self.first_stage_model.decode(z)
@torch.no_grad()
def p_sample(
    self,
    x,
    c,
    t,
    clip_denoised=False,
    repeat_noise=False,
    return_codebook_ids=False,
    quantize_denoised=False,
    return_x0=False,
    temperature=1.0,
    noise_dropout=0.0,
    score_corrector=None,
    corrector_kwargs=None,
):
    """
    Draw one reverse-diffusion sample from the current state ``x`` at timestep ``t``.

    :param x: current noisy Tensor; its first dimension is the batch size.
    :param c: conditioning, forwarded to ``self.p_mean_variance``.
    :param t: Tensor of timesteps, one per batch element; no noise is added
              where ``t == 0``.
    :param clip_denoised: forwarded to ``self.p_mean_variance``.
    :param repeat_noise: forwarded to ``noise_like`` as its ``repeat`` flag.
    :param return_codebook_ids: no longer supported; a truthy value raises.
    :param quantize_denoised: forwarded to ``self.p_mean_variance``.
    :param return_x0: also return the fourth value of ``self.p_mean_variance``.
    :param temperature: multiplier applied to the noise.
    :param noise_dropout: dropout probability applied to the noise when > 0.
    :param score_corrector: forwarded to ``self.p_mean_variance``.
    :param corrector_kwargs: forwarded to ``self.p_mean_variance``.
    :return: the sample, or ``(sample, x0)`` when ``return_x0`` is set.
    :raises DeprecationWarning: if ``return_codebook_ids`` is truthy.
    """
    # Fail before running the model: this mode always ended in this raise.
    if return_codebook_ids:
        raise DeprecationWarning("Support dropped.")

    outputs = self.p_mean_variance(
        x=x,
        c=c,
        t=t,
        clip_denoised=clip_denoised,
        return_codebook_ids=return_codebook_ids,
        quantize_denoised=quantize_denoised,
        return_x0=return_x0,
        score_corrector=score_corrector,
        corrector_kwargs=corrector_kwargs,
    )
    if return_x0:
        model_mean, _, model_log_variance, x0 = outputs
    else:
        model_mean, _, model_log_variance = outputs

    batch = x.shape[0]
    noise = noise_like(x.shape, x.device, repeat_noise) * temperature
    if noise_dropout > 0.0:
        noise = torch.nn.functional.dropout(noise, p=noise_dropout)
    # no noise when t == 0
    nonzero_mask = (
        (1 - (t == 0).float()).reshape(batch, *((1,) * (len(x.shape) - 1))).contiguous()
    )

    # exp(0.5 * log_variance) is the standard deviation of the step.
    sample = model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
    if return_x0:
        return sample, x0
    return sample
@torch.no_grad()
def p_sample_loop(
    self,
    cond,
    shape,
    return_intermediates=False,
    x_T=None,
    verbose=True,
    callback=None,
    timesteps=None,
    quantize_denoised=False,
    mask=None,
    x0=None,
    img_callback=None,
    start_T=None,
    log_every_t=None,
):
    """
    Run the reverse diffusion loop from the last timestep down to 0.

    :param cond: conditioning passed to ``self.p_sample``; re-noised with
                 ``self.q_sample`` each step when ``self.shorten_cond_schedule``.
    :param shape: shape of the sample; ``shape[0]`` is the batch size.
    :param return_intermediates: also return the list of logged states.
    :param x_T: starting Tensor; random normal noise on the device of
                ``self.betas`` when None.
    :param verbose: wrap the step iterator in a tqdm progress bar.
    :param callback: called with the step index after every step.
    :param timesteps: number of steps; ``self.num_timesteps`` when None.
    :param quantize_denoised: forwarded to ``self.p_sample``.
    :param mask: if given, each step is blended as
                 ``q_sample(x0, ts) * mask + (1 - mask) * img``.
    :param x0: Tensor used together with ``mask``; required when ``mask`` is given.
    :param img_callback: called with ``(img, step)`` after every step.
    :param start_T: upper bound on the number of steps.
    :param log_every_t: logging interval; ``self.log_every_t`` when falsy.
    :return: the final Tensor, or ``(final, intermediates)`` when
             ``return_intermediates`` is set.
    """
    log_every_t = log_every_t or self.log_every_t
    device = self.betas.device
    batch = shape[0]

    img = torch.randn(shape, device=device) if x_T is None else x_T
    intermediates = [img]

    if timesteps is None:
        timesteps = self.num_timesteps
    if start_T is not None:
        timesteps = min(timesteps, start_T)

    if verbose:
        iterator = tqdm(
            reversed(range(0, timesteps)), desc="Sampling t", total=timesteps
        )
    else:
        iterator = reversed(range(0, timesteps))

    use_mask = mask is not None
    if use_mask:
        assert x0 is not None
        assert x0.shape[2:3] == mask.shape[2:3]  # spatial size has to match

    for step in iterator:
        ts = torch.full((batch,), step, device=device, dtype=torch.long)
        if self.shorten_cond_schedule:
            assert self.model.conditioning_key != "hybrid"
            tc = self.cond_ids[ts].to(cond.device)
            cond = self.q_sample(x_start=cond, t=tc, noise=torch.randn_like(cond))

        img = self.p_sample(
            img,
            cond,
            ts,
            clip_denoised=self.clip_denoised,
            quantize_denoised=quantize_denoised,
        )
        if use_mask:
            img_orig = self.q_sample(x0, ts)
            img = img_orig * mask + (1.0 - mask) * img

        if step % log_every_t == 0 or step == timesteps - 1:
            intermediates.append(img)
        if callback:
            callback(step)
        if img_callback:
            img_callback(img, step)

    return (img, intermediates) if return_intermediates else img
@torch.no_grad()
def sample_log(
    self,
    cond,
    batch_size,
    ddim,
    ddim_steps,
    unconditional_guidance_scale=1.0,
    unconditional_conditioning=None,
    use_plms=False,
    mask=None,
    **kwargs,
):
    """
    Sample latents with the DDIM sampler or with ``self.sample``.

    :param cond: conditioning handed to the sampler.
    :param batch_size: number of samples to draw.
    :param ddim: use ``DDIMSampler`` (only when ``use_plms`` is False).
    :param ddim_steps: number of DDIM steps.
    :param unconditional_guidance_scale: guidance scale; only the DDIM sampler
                                         receives it.
    :param unconditional_conditioning: unconditional conditioning; only the
                                       DDIM sampler receives it.
    :param use_plms: when True the DDIM branch is skipped and ``self.sample``
                     is used instead.
    :param mask: optional mask; when given, its last two dimensions define the
                 latent shape handed to the DDIM sampler.
    :param kwargs: forwarded to the chosen sampler (``eta`` is DDIM-only and is
                   dropped for ``self.sample``).
    :return: ``(samples, intermediates)`` as produced by the sampler.
    """
    if mask is not None:
        shape = (self.channels, mask.size()[-2], mask.size()[-1])
    else:
        shape = (self.channels, self.latent_t_size, self.latent_f_size)

    if ddim and not use_plms:
        ddim_sampler = DDIMSampler(self)
        samples, intermediates = ddim_sampler.sample(
            ddim_steps,
            batch_size,
            shape,
            cond,
            verbose=False,
            unconditional_guidance_scale=unconditional_guidance_scale,
            unconditional_conditioning=unconditional_conditioning,
            mask=mask,
            **kwargs,
        )
    else:
        # self.sample passes unknown keywords on to p_sample_loop, which has a
        # fixed keyword list: the guidance arguments and `eta` used to make
        # this branch fail with TypeError, so they are not forwarded.
        if unconditional_guidance_scale != 1.0:
            import warnings

            warnings.warn(
                "unconditional guidance is only applied by the DDIM sampler; "
                "it is ignored here"
            )
        ddpm_kwargs = {key: value for key, value in kwargs.items() if key != "eta"}
        samples, intermediates = self.sample(
            cond=cond,
            batch_size=batch_size,
            return_intermediates=True,
            mask=mask,
            **ddpm_kwargs,
        )

    # Previously an always-None placeholder was returned here.
    return samples, intermediates
os.path.join(self.get_log_dir(), name) + # os.makedirs(waveform_save_path, exist_ok=True) + # print("Waveform save path: ", waveform_save_path) + + with self.ema_scope("Generate"): + for batch in batchs: + z, c = self.get_input( + batch, + self.first_stage_key, + cond_key=self.cond_stage_key, + return_first_stage_outputs=False, + force_c_encode=True, + return_original_cond=False, + bs=None, + ) + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_candidate_gen_per_text + c = torch.cat([c] * n_candidate_gen_per_text, dim=0) + text = text * n_candidate_gen_per_text + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = ( + self.cond_stage_model.get_unconditional_condition(batch_size) + ) + + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, + ) + + if(torch.max(torch.abs(samples)) > 1e2): + samples = torch.clip(samples, min=-10, max=10) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform(mel) + + if waveform.shape[0] > 1: + similarity = self.cond_stage_model.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + + best_index = [] + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + # print("Similarity between generated audio and text", similarity) + # print("Choose the following indexes:", best_index) + + return waveform + + @torch.no_grad() + def generate_sample_masked( + self, + batchs, + ddim_steps=200, + ddim_eta=1.0, + x_T=None, + n_candidate_gen_per_text=1, + unconditional_guidance_scale=1.0, + unconditional_conditioning=None, + name="waveform", + use_plms=False, + 
time_mask_ratio_start_and_end=(0.25, 0.75), + freq_mask_ratio_start_and_end=(0.75, 1.0), + save=False, + **kwargs, + ): + # Generate n_candidate_gen_per_text times and select the best + # Batch: audio, text, fnames + assert x_T is None + try: + batchs = iter(batchs) + except TypeError: + raise ValueError("The first input argument should be an iterable object") + + if use_plms: + assert ddim_steps is not None + use_ddim = ddim_steps is not None + # waveform_save_path = os.path.join(self.get_log_dir(), name) + # os.makedirs(waveform_save_path, exist_ok=True) + # print("Waveform save path: ", waveform_save_path) + + with self.ema_scope("Generate"): + for batch in batchs: + z, c = self.get_input( + batch, + self.first_stage_key, + cond_key=self.cond_stage_key, + return_first_stage_outputs=False, + force_c_encode=True, + return_original_cond=False, + bs=None, + ) + text = super().get_input(batch, "text") + + # Generate multiple samples + batch_size = z.shape[0] * n_candidate_gen_per_text + + _, h, w = z.shape[0], z.shape[2], z.shape[3] + + mask = torch.ones(batch_size, h, w).to(self.device) + + mask[:, int(h * time_mask_ratio_start_and_end[0]) : int(h * time_mask_ratio_start_and_end[1]), :] = 0 + mask[:, :, int(w * freq_mask_ratio_start_and_end[0]) : int(w * freq_mask_ratio_start_and_end[1])] = 0 + mask = mask[:, None, ...] 
+ + c = torch.cat([c] * n_candidate_gen_per_text, dim=0) + text = text * n_candidate_gen_per_text + + if unconditional_guidance_scale != 1.0: + unconditional_conditioning = ( + self.cond_stage_model.get_unconditional_condition(batch_size) + ) + + samples, _ = self.sample_log( + cond=c, + batch_size=batch_size, + x_T=x_T, + ddim=use_ddim, + ddim_steps=ddim_steps, + eta=ddim_eta, + unconditional_guidance_scale=unconditional_guidance_scale, + unconditional_conditioning=unconditional_conditioning, + use_plms=use_plms, mask=mask, x0=torch.cat([z] * n_candidate_gen_per_text) + ) + + mel = self.decode_first_stage(samples) + + waveform = self.mel_spectrogram_to_waveform(mel) + + if waveform.shape[0] > 1: + similarity = self.cond_stage_model.cos_similarity( + torch.FloatTensor(waveform).squeeze(1), text + ) + + best_index = [] + for i in range(z.shape[0]): + candidates = similarity[i :: z.shape[0]] + max_index = torch.argmax(candidates).item() + best_index.append(i + max_index * z.shape[0]) + + waveform = waveform[best_index] + # print("Similarity between generated audio and text", similarity) + # print("Choose the following indexes:", best_index) + + return waveform \ No newline at end of file diff --git a/audioldm/pipeline.py b/audioldm/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..b08e1f77206483025ce027588c2dea1de78ae26c --- /dev/null +++ b/audioldm/pipeline.py @@ -0,0 +1,301 @@ +import os + +import argparse +import yaml +import torch +from torch import autocast +from tqdm import tqdm, trange + +from audioldm import LatentDiffusion, seed_everything +from audioldm.utils import default_audioldm_config, get_duration, get_bit_depth, get_metadata, download_checkpoint +from audioldm.audio import wav_to_fbank, TacotronSTFT, read_wav_file +from audioldm.latent_diffusion.ddim import DDIMSampler +from einops import repeat +import os + +def make_batch_for_text_to_audio(text, waveform=None, fbank=None, batchsize=1): + text = [text] * batchsize + if 
batchsize < 1: + print("Warning: Batchsize must be at least 1. Batchsize is set to .") + + if(fbank is None): + fbank = torch.zeros((batchsize, 1024, 64)) # Not used, here to keep the code format + else: + fbank = torch.FloatTensor(fbank) + fbank = fbank.expand(batchsize, 1024, 64) + assert fbank.size(0) == batchsize + + stft = torch.zeros((batchsize, 1024, 512)) # Not used + + if(waveform is None): + waveform = torch.zeros((batchsize, 160000)) # Not used + else: + waveform = torch.FloatTensor(waveform) + waveform = waveform.expand(batchsize, -1) + assert waveform.size(0) == batchsize + + fname = [""] * batchsize # Not used + + batch = ( + fbank, + stft, + None, + fname, + waveform, + text, + ) + return batch + +def round_up_duration(duration): + return int(round(duration/2.5) + 1) * 2.5 + +def build_model( + ckpt_path=None, + config=None, + model_name="audioldm-s-full" +): + print("Load AudioLDM: %s", model_name) + + if(ckpt_path is None): + ckpt_path = get_metadata()[model_name]["path"] + + if(not os.path.exists(ckpt_path)): + download_checkpoint(model_name) + + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config(model_name) + + # Use text as condition instead of using waveform during training + config["model"]["params"]["device"] = device + config["model"]["params"]["cond_stage_key"] = "text" + + # No normalization here + latent_diffusion = LatentDiffusion(**config["model"]["params"]) + + resume_from_checkpoint = ckpt_path + + checkpoint = torch.load(resume_from_checkpoint, map_location=device) + latent_diffusion.load_state_dict(checkpoint["state_dict"]) + + latent_diffusion.eval() + latent_diffusion = latent_diffusion.to(device) + + latent_diffusion.cond_stage_model.embed_mode = "text" + return latent_diffusion + +def 
duration_to_latent_t_size(duration): + return int(duration * 25.6) + +def set_cond_audio(latent_diffusion): + latent_diffusion.cond_stage_key = "waveform" + latent_diffusion.cond_stage_model.embed_mode="audio" + return latent_diffusion + +def set_cond_text(latent_diffusion): + latent_diffusion.cond_stage_key = "text" + latent_diffusion.cond_stage_model.embed_mode="text" + return latent_diffusion + +def text_to_audio( + latent_diffusion, + text, + original_audio_file_path = None, + seed=42, + ddim_steps=200, + duration=10, + batchsize=1, + guidance_scale=2.5, + n_candidate_gen_per_text=3, + config=None, +): + seed_everything(int(seed)) + waveform = None + if(original_audio_file_path is not None): + waveform = read_wav_file(original_audio_file_path, int(duration * 102.4) * 160) + + batch = make_batch_for_text_to_audio(text, waveform=waveform, batchsize=batchsize) + + latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + + if(waveform is not None): + print("Generate audio that has similar content as %s" % original_audio_file_path) + latent_diffusion = set_cond_audio(latent_diffusion) + else: + print("Generate audio using text %s" % text) + latent_diffusion = set_cond_text(latent_diffusion) + + with torch.no_grad(): + waveform = latent_diffusion.generate_sample( + [batch], + unconditional_guidance_scale=guidance_scale, + ddim_steps=ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + duration=duration, + ) + return waveform + +def style_transfer( + latent_diffusion, + text, + original_audio_file_path, + transfer_strength, + seed=42, + duration=10, + batchsize=1, + guidance_scale=2.5, + ddim_steps=200, + config=None, +): + if torch.cuda.is_available(): + device = torch.device("cuda:0") + else: + device = torch.device("cpu") + + assert original_audio_file_path is not None, "You need to provide the original audio file path" + + audio_file_duration = get_duration(original_audio_file_path) + + assert get_bit_depth(original_audio_file_path) == 
16, "The bit depth of the original audio file %s must be 16" % original_audio_file_path + + # if(duration > 20): + # print("Warning: The duration of the audio file %s must be less than 20 seconds. Longer duration will result in Nan in model output (we are still debugging that); Automatically set duration to 20 seconds") + # duration = 20 + + if(duration >= audio_file_duration): + print("Warning: Duration you specified %s-seconds must equal or smaller than the audio file duration %ss" % (duration, audio_file_duration)) + duration = round_up_duration(audio_file_duration) + print("Set new duration as %s-seconds" % duration) + + # duration = round_up_duration(duration) + + latent_diffusion = set_cond_text(latent_diffusion) + + if config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config() + + seed_everything(int(seed)) + # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + latent_diffusion.cond_stage_model.embed_mode = "text" + + fn_STFT = TacotronSTFT( + config["preprocessing"]["stft"]["filter_length"], + config["preprocessing"]["stft"]["hop_length"], + config["preprocessing"]["stft"]["win_length"], + config["preprocessing"]["mel"]["n_mel_channels"], + config["preprocessing"]["audio"]["sampling_rate"], + config["preprocessing"]["mel"]["mel_fmin"], + config["preprocessing"]["mel"]["mel_fmax"], + ) + + mel, _, _ = wav_to_fbank( + original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT + ) + mel = mel.unsqueeze(0).unsqueeze(0).to(device) + mel = repeat(mel, "1 ... 
-> b ...", b=batchsize) + init_latent = latent_diffusion.get_first_stage_encoding( + latent_diffusion.encode_first_stage(mel) + ) # move to latent space, encode and sample + if(torch.max(torch.abs(init_latent)) > 1e2): + init_latent = torch.clip(init_latent, min=-10, max=10) + sampler = DDIMSampler(latent_diffusion) + sampler.make_schedule(ddim_num_steps=ddim_steps, ddim_eta=1.0, verbose=False) + + t_enc = int(transfer_strength * ddim_steps) + prompts = text + + with torch.no_grad(): + with autocast("cuda"): + with latent_diffusion.ema_scope(): + uc = None + if guidance_scale != 1.0: + uc = latent_diffusion.cond_stage_model.get_unconditional_condition( + batchsize + ) + + c = latent_diffusion.get_learned_conditioning([prompts] * batchsize) + z_enc = sampler.stochastic_encode( + init_latent, torch.tensor([t_enc] * batchsize).to(device) + ) + samples = sampler.decode( + z_enc, + c, + t_enc, + unconditional_guidance_scale=guidance_scale, + unconditional_conditioning=uc, + ) + # x_samples = latent_diffusion.decode_first_stage(samples) # Will result in Nan in output + # print(torch.sum(torch.isnan(samples))) + x_samples = latent_diffusion.decode_first_stage(samples) + # print(x_samples) + x_samples = latent_diffusion.decode_first_stage(samples[:,:,:-3,:]) + # print(x_samples) + waveform = latent_diffusion.first_stage_model.decode_to_waveform( + x_samples + ) + + return waveform + +def super_resolution_and_inpainting( + latent_diffusion, + text, + original_audio_file_path = None, + seed=42, + ddim_steps=200, + duration=None, + batchsize=1, + guidance_scale=2.5, + n_candidate_gen_per_text=3, + time_mask_ratio_start_and_end=(0.10, 0.15), # regenerate the 10% to 15% of the time steps in the spectrogram + # time_mask_ratio_start_and_end=(1.0, 1.0), # no inpainting + # freq_mask_ratio_start_and_end=(0.75, 1.0), # regenerate the higher 75% to 100% mel bins + freq_mask_ratio_start_and_end=(1.0, 1.0), # no super-resolution + config=None, +): + seed_everything(int(seed)) + if 
config is not None: + assert type(config) is str + config = yaml.load(open(config, "r"), Loader=yaml.FullLoader) + else: + config = default_audioldm_config() + fn_STFT = TacotronSTFT( + config["preprocessing"]["stft"]["filter_length"], + config["preprocessing"]["stft"]["hop_length"], + config["preprocessing"]["stft"]["win_length"], + config["preprocessing"]["mel"]["n_mel_channels"], + config["preprocessing"]["audio"]["sampling_rate"], + config["preprocessing"]["mel"]["mel_fmin"], + config["preprocessing"]["mel"]["mel_fmax"], + ) + + # waveform = read_wav_file(original_audio_file_path, None) + mel, _, _ = wav_to_fbank( + original_audio_file_path, target_length=int(duration * 102.4), fn_STFT=fn_STFT + ) + + batch = make_batch_for_text_to_audio(text, fbank=mel[None,...], batchsize=batchsize) + + # latent_diffusion.latent_t_size = duration_to_latent_t_size(duration) + latent_diffusion = set_cond_text(latent_diffusion) + + with torch.no_grad(): + waveform = latent_diffusion.generate_sample_masked( + [batch], + unconditional_guidance_scale=guidance_scale, + ddim_steps=ddim_steps, + n_candidate_gen_per_text=n_candidate_gen_per_text, + duration=duration, + time_mask_ratio_start_and_end=time_mask_ratio_start_and_end, + freq_mask_ratio_start_and_end=freq_mask_ratio_start_and_end + ) + return waveform \ No newline at end of file diff --git a/audioldm/utils.py b/audioldm/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..5401b29d4366774233f1bf4a9e7fcb7ce214187e --- /dev/null +++ b/audioldm/utils.py @@ -0,0 +1,281 @@ +import contextlib +import importlib + +from inspect import isfunction +import os +import soundfile as sf +import time +import wave + +import urllib.request +import progressbar + +CACHE_DIR = os.getenv( + "AUDIOLDM_CACHE_DIR", + os.path.join(os.path.expanduser("~"), ".cache/audioldm")) + +def get_duration(fname): + with contextlib.closing(wave.open(fname, 'r')) as f: + frames = f.getnframes() + rate = f.getframerate() + return frames / 
float(rate) + +def get_bit_depth(fname): + with contextlib.closing(wave.open(fname, 'r')) as f: + bit_depth = f.getsampwidth() * 8 + return bit_depth + +def get_time(): + t = time.localtime() + return time.strftime("%d_%m_%Y_%H_%M_%S", t) + +def seed_everything(seed): + import random, os + import numpy as np + import torch + + random.seed(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = True + + +def save_wave(waveform, savepath, name="outwav"): + if type(name) is not list: + name = [name] * waveform.shape[0] + + for i in range(waveform.shape[0]): + path = os.path.join( + savepath, + "%s_%s.wav" + % ( + os.path.basename(name[i]) + if (not ".wav" in name[i]) + else os.path.basename(name[i]).split(".")[0], + i, + ), + ) + print("Save audio to %s" % path) + sf.write(path, waveform[i, 0], samplerate=16000) + + +def exists(x): + return x is not None + + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params + + +def get_obj_from_str(string, reload=False): + module, cls = string.rsplit(".", 1) + if reload: + module_imp = importlib.import_module(module) + importlib.reload(module_imp) + return getattr(importlib.import_module(module, package=None), cls) + + +def instantiate_from_config(config): + if not "target" in config: + if config == "__is_first_stage__": + return None + elif config == "__is_unconditional__": + return None + raise KeyError("Expected key `target` to instantiate.") + return get_obj_from_str(config["target"])(**config.get("params", dict())) + + +def default_audioldm_config(model_name="audioldm-s-full"): + basic_config = { + "wave_file_save_path": 
"./output", + "id": { + "version": "v1", + "name": "default", + "root": "/mnt/fast/nobackup/users/hl01486/projects/general_audio_generation/AudioLDM-python/config/default/latent_diffusion.yaml", + }, + "preprocessing": { + "audio": {"sampling_rate": 16000, "max_wav_value": 32768}, + "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024}, + "mel": { + "n_mel_channels": 64, + "mel_fmin": 0, + "mel_fmax": 8000, + "freqm": 0, + "timem": 0, + "blur": False, + "mean": -4.63, + "std": 2.74, + "target_length": 1024, + }, + }, + "model": { + "device": "cuda", + "target": "audioldm.pipline.LatentDiffusion", + "params": { + "base_learning_rate": 5e-06, + "linear_start": 0.0015, + "linear_end": 0.0195, + "num_timesteps_cond": 1, + "log_every_t": 200, + "timesteps": 1000, + "first_stage_key": "fbank", + "cond_stage_key": "waveform", + "latent_t_size": 256, + "latent_f_size": 16, + "channels": 8, + "cond_stage_trainable": True, + "conditioning_key": "film", + "monitor": "val/loss_simple_ema", + "scale_by_std": True, + "unet_config": { + "target": "audioldm.latent_diffusion.openaimodel.UNetModel", + "params": { + "image_size": 64, + "extra_film_condition_dim": 512, + "extra_film_use_concat": True, + "in_channels": 8, + "out_channels": 8, + "model_channels": 128, + "attention_resolutions": [8, 4, 2], + "num_res_blocks": 2, + "channel_mult": [1, 2, 3, 5], + "num_head_channels": 32, + "use_spatial_transformer": True, + }, + }, + "first_stage_config": { + "base_learning_rate": 4.5e-05, + "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL", + "params": { + "monitor": "val/rec_loss", + "image_key": "fbank", + "subband": 1, + "embed_dim": 8, + "time_shuffle": 1, + "ddconfig": { + "double_z": True, + "z_channels": 8, + "resolution": 256, + "downsample_time": False, + "in_channels": 1, + "out_ch": 1, + "ch": 128, + "ch_mult": [1, 2, 4], + "num_res_blocks": 2, + "attn_resolutions": [], + "dropout": 0.0, + }, + }, + }, + "cond_stage_config": { + "target": 
"audioldm.clap.encoders.CLAPAudioEmbeddingClassifierFreev2", + "params": { + "key": "waveform", + "sampling_rate": 16000, + "embed_mode": "audio", + "unconditional_prob": 0.1, + }, + }, + }, + }, + } + + if("-l-" in model_name): + basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 256 + basic_config["model"]["params"]["unet_config"]["params"]["num_head_channels"] = 64 + elif("-m-" in model_name): + basic_config["model"]["params"]["unet_config"]["params"]["model_channels"] = 192 + basic_config["model"]["params"]["cond_stage_config"]["params"]["amodel"] = "HTSAT-base" # This model use a larger HTAST + + return basic_config + +def get_metadata(): + return { + "audioldm-s-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-full.ckpt", + ), + "url": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1", + }, + "audioldm-l-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-l-full.ckpt", + ), + "url": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1", + }, + "audioldm-s-full-v2": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-full-v2.ckpt", + ), + "url": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1", + }, + "audioldm-m-text-ft": { + "path": os.path.join( + CACHE_DIR, + "audioldm-m-text-ft.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1", + }, + "audioldm-s-text-ft": { + "path": os.path.join( + CACHE_DIR, + "audioldm-s-text-ft.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1", + }, + "audioldm-m-full": { + "path": os.path.join( + CACHE_DIR, + "audioldm-m-full.ckpt", + ), + "url": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1", + }, + } + +class MyProgressBar(): + def __init__(self): + self.pbar = None + + def __call__(self, block_num, block_size, total_size): + if not self.pbar: + self.pbar=progressbar.ProgressBar(maxval=total_size) + 
self.pbar.start() + + downloaded = block_num * block_size + if downloaded < total_size: + self.pbar.update(downloaded) + else: + self.pbar.finish() + +def download_checkpoint(checkpoint_name="audioldm-s-full"): + meta = get_metadata() + if(checkpoint_name not in meta.keys()): + print("The model name you provided is not supported. Please use one of the following: ", meta.keys()) + + if not os.path.exists(meta[checkpoint_name]["path"]) or os.path.getsize(meta[checkpoint_name]["path"]) < 2*10**9: + os.makedirs(os.path.dirname(meta[checkpoint_name]["path"]), exist_ok=True) + print(f"Downloading the main structure of {checkpoint_name} into {os.path.dirname(meta[checkpoint_name]['path'])}") + + urllib.request.urlretrieve(meta[checkpoint_name]["url"], meta[checkpoint_name]["path"], MyProgressBar()) + print( + "Weights downloaded in: {} Size: {}".format( + meta[checkpoint_name]["path"], + os.path.getsize(meta[checkpoint_name]["path"]), + ) + ) + \ No newline at end of file diff --git a/audioldm/variational_autoencoder/__init__.py b/audioldm/variational_autoencoder/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..08b2a9b9698e02918d7b0dd9fe0431b2847e5aa2 --- /dev/null +++ b/audioldm/variational_autoencoder/__init__.py @@ -0,0 +1 @@ +from .autoencoder import AutoencoderKL \ No newline at end of file diff --git a/audioldm/variational_autoencoder/__pycache__/__init__.cpython-310.pyc b/audioldm/variational_autoencoder/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..bb81d2d1b8f4e3d6bb161d5861ff6494c1bdecbf Binary files /dev/null and b/audioldm/variational_autoencoder/__pycache__/__init__.cpython-310.pyc differ diff --git a/audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-310.pyc b/audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f2ec08171f8ba332e1bc509325ef625248ec081e Binary files 
/dev/null and b/audioldm/variational_autoencoder/__pycache__/autoencoder.cpython-310.pyc differ diff --git a/audioldm/variational_autoencoder/__pycache__/distributions.cpython-310.pyc b/audioldm/variational_autoencoder/__pycache__/distributions.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6cb6aa80272bf44284bbf369f9cee9b2f48cc601 Binary files /dev/null and b/audioldm/variational_autoencoder/__pycache__/distributions.cpython-310.pyc differ diff --git a/audioldm/variational_autoencoder/__pycache__/modules.cpython-310.pyc b/audioldm/variational_autoencoder/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..64a50a7a5bf52af461703ffc4c0a562297b38cff Binary files /dev/null and b/audioldm/variational_autoencoder/__pycache__/modules.cpython-310.pyc differ diff --git a/audioldm/variational_autoencoder/autoencoder.py b/audioldm/variational_autoencoder/autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..9dadc849da65d1f9eb82dc75dc777250bf738151 --- /dev/null +++ b/audioldm/variational_autoencoder/autoencoder.py @@ -0,0 +1,135 @@ +import torch +from audioldm.latent_diffusion.ema import * +from audioldm.variational_autoencoder.modules import Encoder, Decoder +from audioldm.variational_autoencoder.distributions import DiagonalGaussianDistribution + +from audioldm.hifigan.utilities import get_vocoder, vocoder_infer + + +class AutoencoderKL(nn.Module): + def __init__( + self, + ddconfig=None, + lossconfig=None, + image_key="fbank", + embed_dim=None, + time_shuffle=1, + subband=1, + ckpt_path=None, + reload_from_ckpt=None, + ignore_keys=[], + colorize_nlabels=None, + monitor=None, + base_learning_rate=1e-5, + scale_factor=1 + ): + super().__init__() + + self.encoder = Encoder(**ddconfig) + self.decoder = Decoder(**ddconfig) + + self.subband = int(subband) + + if self.subband > 1: + print("Use subband decomposition %s" % self.subband) + + self.quant_conv = 
torch.nn.Conv2d(2 * ddconfig["z_channels"], 2 * embed_dim, 1) + self.post_quant_conv = torch.nn.Conv2d(embed_dim, ddconfig["z_channels"], 1) + + self.vocoder = get_vocoder(None, "cpu") + self.embed_dim = embed_dim + + if monitor is not None: + self.monitor = monitor + + self.time_shuffle = time_shuffle + self.reload_from_ckpt = reload_from_ckpt + self.reloaded = False + self.mean, self.std = None, None + + self.scale_factor = scale_factor + + def encode(self, x): + # x = self.time_shuffle_operation(x) + x = self.freq_split_subband(x) + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + return posterior + + def decode(self, z): + z = self.post_quant_conv(z) + dec = self.decoder(z) + dec = self.freq_merge_subband(dec) + return dec + + def decode_to_waveform(self, dec): + dec = dec.squeeze(1).permute(0, 2, 1) + wav_reconstruction = vocoder_infer(dec, self.vocoder) + return wav_reconstruction + + def forward(self, input, sample_posterior=True): + posterior = self.encode(input) + if sample_posterior: + z = posterior.sample() + else: + z = posterior.mode() + + if self.flag_first_run: + print("Latent size: ", z.size()) + self.flag_first_run = False + + dec = self.decode(z) + + return dec, posterior + + def freq_split_subband(self, fbank): + if self.subband == 1 or self.image_key != "stft": + return fbank + + bs, ch, tstep, fbins = fbank.size() + + assert fbank.size(-1) % self.subband == 0 + assert ch == 1 + + return ( + fbank.squeeze(1) + .reshape(bs, tstep, self.subband, fbins // self.subband) + .permute(0, 2, 1, 3) + ) + + def freq_merge_subband(self, subband_fbank): + if self.subband == 1 or self.image_key != "stft": + return subband_fbank + assert subband_fbank.size(1) == self.subband # Channel dimension + bs, sub_ch, tstep, fbins = subband_fbank.size() + return subband_fbank.permute(0, 2, 1, 3).reshape(bs, tstep, -1).unsqueeze(1) + + def device(self): + return next(self.parameters()).device + + @torch.no_grad() + 
def encode_first_stage(self, x): + return self.encode(x) + + @torch.no_grad() + def decode_first_stage(self, z, predict_cids=False, force_not_quantize=False): + if predict_cids: + if z.dim() == 4: + z = torch.argmax(z.exp(), dim=1).long() + z = self.first_stage_model.quantize.get_codebook_entry(z, shape=None) + z = rearrange(z, "b h w c -> b c h w").contiguous() + + z = 1.0 / self.scale_factor * z + return self.decode(z) + + def get_first_stage_encoding(self, encoder_posterior): + if isinstance(encoder_posterior, DiagonalGaussianDistribution): + z = encoder_posterior.sample() + elif isinstance(encoder_posterior, torch.Tensor): + z = encoder_posterior + else: + raise NotImplementedError( + f"encoder_posterior of type '{type(encoder_posterior)}' not yet implemented" + ) + return self.scale_factor * z \ No newline at end of file diff --git a/audioldm/variational_autoencoder/distributions.py b/audioldm/variational_autoencoder/distributions.py new file mode 100644 index 0000000000000000000000000000000000000000..58eb535e7769f402169ddff77ee45c96ba3650d9 --- /dev/null +++ b/audioldm/variational_autoencoder/distributions.py @@ -0,0 +1,102 @@ +import torch +import numpy as np + + +class AbstractDistribution: + def sample(self): + raise NotImplementedError() + + def mode(self): + raise NotImplementedError() + + +class DiracDistribution(AbstractDistribution): + def __init__(self, value): + self.value = value + + def sample(self): + return self.value + + def mode(self): + return self.value + + +class DiagonalGaussianDistribution(object): + def __init__(self, parameters, deterministic=False): + self.parameters = parameters + self.mean, self.logvar = torch.chunk(parameters, 2, dim=1) + self.logvar = torch.clamp(self.logvar, -30.0, 20.0) + self.deterministic = deterministic + self.std = torch.exp(0.5 * self.logvar) + self.var = torch.exp(self.logvar) + if self.deterministic: + self.var = self.std = torch.zeros_like(self.mean).to( + device=self.parameters.device + ) + + def 
sample(self): + x = self.mean + self.std * torch.randn(self.mean.shape).to( + device=self.parameters.device + ) + return x + + def kl(self, other=None): + if self.deterministic: + return torch.Tensor([0.0]) + else: + if other is None: + return 0.5 * torch.mean( + torch.pow(self.mean, 2) + self.var - 1.0 - self.logvar, + dim=[1, 2, 3], + ) + else: + return 0.5 * torch.mean( + torch.pow(self.mean - other.mean, 2) / other.var + + self.var / other.var + - 1.0 + - self.logvar + + other.logvar, + dim=[1, 2, 3], + ) + + def nll(self, sample, dims=[1, 2, 3]): + if self.deterministic: + return torch.Tensor([0.0]) + logtwopi = np.log(2.0 * np.pi) + return 0.5 * torch.sum( + logtwopi + self.logvar + torch.pow(sample - self.mean, 2) / self.var, + dim=dims, + ) + + def mode(self): + return self.mean + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + source: https://github.com/openai/guided-diffusion/blob/27c20a8fab9cb472df5d6bdd6c8d11c8f430b924/guided_diffusion/losses.py#L12 + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, torch.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for torch.exp(). 
+ logvar1, logvar2 = [ + x if isinstance(x, torch.Tensor) else torch.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + torch.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * torch.exp(-logvar2) + ) diff --git a/audioldm/variational_autoencoder/modules.py b/audioldm/variational_autoencoder/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..e48386d045c1d0e159de33db02af1035159c3447 --- /dev/null +++ b/audioldm/variational_autoencoder/modules.py @@ -0,0 +1,1066 @@ +# pytorch_diffusion + derived encoder decoder +import math +import torch +import torch.nn as nn +import numpy as np +from einops import rearrange + +from audioldm.utils import instantiate_from_config +from audioldm.latent_diffusion.attention import LinearAttention + + +def get_timestep_embedding(timesteps, embedding_dim): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: + From Fairseq. + Build sinusoidal embeddings. + This matches the implementation in tensor2tensor, but differs slightly + from the description in Section 3.5 of "Attention Is All You Need". 
+ """ + assert len(timesteps.shape) == 1 + + half_dim = embedding_dim // 2 + emb = math.log(10000) / (half_dim - 1) + emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb) + emb = emb.to(device=timesteps.device) + emb = timesteps.float()[:, None] * emb[None, :] + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1) + if embedding_dim % 2 == 1: # zero pad + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def nonlinearity(x): + # swish + return x * torch.sigmoid(x) + + +def Normalize(in_channels, num_groups=32): + return torch.nn.GroupNorm( + num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True + ) + + +class Upsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class UpsampleTimeStride4(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=5, stride=1, padding=2 + ) + + def forward(self, x): + x = torch.nn.functional.interpolate(x, scale_factor=(4.0, 2.0), mode="nearest") + if self.with_conv: + x = self.conv(x) + return x + + +class Downsample(nn.Module): + def __init__(self, in_channels, with_conv): + super().__init__() + self.with_conv = with_conv + if self.with_conv: + # Do time downsampling here + # no asymmetric padding in torch conv, must do it ourselves + self.conv = torch.nn.Conv2d( + in_channels, in_channels, kernel_size=3, stride=2, padding=0 + ) + + def forward(self, x): + if self.with_conv: + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + else: + x = 
class DownsampleTimeStride4(nn.Module):
    """Downsample a 4-D feature map by 4 along dim 2 and by 2 along dim 3.

    With ``with_conv`` a 5x5 strided convolution is used (after padding one
    zero row/column on the bottom/right); otherwise a (4, 2) average pool.
    """

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if not self.with_conv:
            return
        # torch convs only pad symmetrically, so forward() adds the extra
        # bottom/right zero padding itself before this strided conv runs.
        self.conv = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=5,
            stride=(4, 2),
            padding=1,
        )

    def forward(self, x):
        functional = torch.nn.functional
        if not self.with_conv:
            return functional.avg_pool2d(x, kernel_size=(4, 2), stride=(4, 2))
        # (left, right, top, bottom): one extra zero column and row.
        padded = functional.pad(x, (0, 1, 0, 1), mode="constant", value=0)
        return self.conv(padded)
class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial positions of a feature map.

    Queries, keys and values come from 1x1 convolutions of the normalized
    input; the attended values pass through a final 1x1 convolution and are
    added back to the input (residual connection).
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        def pointwise_conv():
            return nn.Conv2d(
                in_channels, in_channels, kernel_size=1, stride=1, padding=0
            )

        self.norm = Normalize(in_channels)
        self.q = pointwise_conv()
        self.k = pointwise_conv()
        self.v = pointwise_conv()
        self.proj_out = pointwise_conv()

    def forward(self, x):
        normed = self.norm(x)
        query = self.q(normed)
        key = self.k(normed)
        value = self.v(normed)

        batch, channels, height, width = query.shape
        positions = height * width

        # scores[b, i, j] = sum_c query[b, i, c] * key[b, c, j]
        query = query.reshape(batch, channels, positions).contiguous()
        query = query.permute(0, 2, 1).contiguous()  # b, hw, c
        key = key.reshape(batch, channels, positions).contiguous()  # b, c, hw
        scores = torch.bmm(query, key).contiguous()  # b, hw, hw
        scores = scores * (int(channels) ** (-0.5))
        weights = torch.nn.functional.softmax(scores, dim=2)

        # out[b, c, j] = sum_i value[b, c, i] * weights_t[b, i, j]
        value = value.reshape(batch, channels, positions).contiguous()
        weights_t = weights.permute(0, 2, 1).contiguous()  # key index first
        attended = torch.bmm(value, weights_t).contiguous()  # b, c, hw
        attended = attended.reshape(batch, channels, height, width).contiguous()

        return x + self.proj_out(attended)
return nn.Identity(in_channels) + else: + return LinAttnBlock(in_channels) + + +class Model(nn.Module): + def __init__( + self, + *, + ch, + out_ch, + ch_mult=(1, 2, 4, 8), + num_res_blocks, + attn_resolutions, + dropout=0.0, + resamp_with_conv=True, + in_channels, + resolution, + use_timestep=True, + use_linear_attn=False, + attn_type="vanilla", + ): + super().__init__() + if use_linear_attn: + attn_type = "linear" + self.ch = ch + self.temb_ch = self.ch * 4 + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.in_channels = in_channels + + self.use_timestep = use_timestep + if self.use_timestep: + # timestep embedding + self.temb = nn.Module() + self.temb.dense = nn.ModuleList( + [ + torch.nn.Linear(self.ch, self.temb_ch), + torch.nn.Linear(self.temb_ch, self.temb_ch), + ] + ) + + # downsampling + self.conv_in = torch.nn.Conv2d( + in_channels, self.ch, kernel_size=3, stride=1, padding=1 + ) + + curr_res = resolution + in_ch_mult = (1,) + tuple(ch_mult) + self.down = nn.ModuleList() + for i_level in range(self.num_resolutions): + block = nn.ModuleList() + attn = nn.ModuleList() + block_in = ch * in_ch_mult[i_level] + block_out = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks): + block.append( + ResnetBlock( + in_channels=block_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + down = nn.Module() + down.block = block + down.attn = attn + if i_level != self.num_resolutions - 1: + down.downsample = Downsample(block_in, resamp_with_conv) + curr_res = curr_res // 2 + self.down.append(down) + + # middle + self.mid = nn.Module() + self.mid.block_1 = ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + self.mid.attn_1 = make_attn(block_in, attn_type=attn_type) + self.mid.block_2 = 
ResnetBlock( + in_channels=block_in, + out_channels=block_in, + temb_channels=self.temb_ch, + dropout=dropout, + ) + + # upsampling + self.up = nn.ModuleList() + for i_level in reversed(range(self.num_resolutions)): + block = nn.ModuleList() + attn = nn.ModuleList() + block_out = ch * ch_mult[i_level] + skip_in = ch * ch_mult[i_level] + for i_block in range(self.num_res_blocks + 1): + if i_block == self.num_res_blocks: + skip_in = ch * in_ch_mult[i_level] + block.append( + ResnetBlock( + in_channels=block_in + skip_in, + out_channels=block_out, + temb_channels=self.temb_ch, + dropout=dropout, + ) + ) + block_in = block_out + if curr_res in attn_resolutions: + attn.append(make_attn(block_in, attn_type=attn_type)) + up = nn.Module() + up.block = block + up.attn = attn + if i_level != 0: + up.upsample = Upsample(block_in, resamp_with_conv) + curr_res = curr_res * 2 + self.up.insert(0, up) # prepend to get consistent order + + # end + self.norm_out = Normalize(block_in) + self.conv_out = torch.nn.Conv2d( + block_in, out_ch, kernel_size=3, stride=1, padding=1 + ) + + def forward(self, x, t=None, context=None): + # assert x.shape[2] == x.shape[3] == self.resolution + if context is not None: + # assume aligned context, cat along channel axis + x = torch.cat((x, context), dim=1) + if self.use_timestep: + # timestep embedding + assert t is not None + temb = get_timestep_embedding(t, self.ch) + temb = self.temb.dense[0](temb) + temb = nonlinearity(temb) + temb = self.temb.dense[1](temb) + else: + temb = None + + # downsampling + hs = [self.conv_in(x)] + for i_level in range(self.num_resolutions): + for i_block in range(self.num_res_blocks): + h = self.down[i_level].block[i_block](hs[-1], temb) + if len(self.down[i_level].attn) > 0: + h = self.down[i_level].attn[i_block](h) + hs.append(h) + if i_level != self.num_resolutions - 1: + hs.append(self.down[i_level].downsample(hs[-1])) + + # middle + h = hs[-1] + h = self.mid.block_1(h, temb) + h = self.mid.attn_1(h) + h = 
class Encoder(nn.Module):
    """Convolutional encoder: ResNet stages with optional attention.

    Each resolution level applies ``num_res_blocks`` ResNet blocks (each
    followed by attention when the current resolution is listed in
    ``attn_resolutions``) and, except for the last level, a downsampling
    layer. Two ResNet blocks with attention in between form the middle, and
    a final norm / swish / 3x3 conv produces the output.

    Args:
        ch: Base channel count; level ``i`` uses ``ch * ch_mult[i]``.
        out_ch: Accepted for config compatibility; not used by this class.
        ch_mult: Per-level channel multipliers.
        num_res_blocks: ResNet blocks per level.
        attn_resolutions: Resolutions at which attention is inserted.
        dropout: Dropout probability for the ResNet blocks.
        resamp_with_conv: Forwarded to the downsampling layers.
        in_channels: Channels of the input tensor.
        resolution: Nominal input resolution, used only to decide where
            attention goes.
        z_channels: Latent channels; the output has ``2 * z_channels``
            channels when ``double_z`` is true.
        double_z: Whether to emit twice ``z_channels`` channels.
        use_linear_attn: If true, forces ``attn_type="linear"``.
        attn_type: Attention flavour passed to ``make_attn``.
        downsample_time_stride4_levels: Levels that use
            ``DownsampleTimeStride4`` instead of ``Downsample``. ``None``
            (the default) means no such level.
        **ignore_kwargs: Extra config entries, ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        downsample_time_stride4_levels=None,
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        # A ``[]`` default would be one list object shared by every instance
        # that relies on the default; create a fresh one per instance.
        if downsample_time_stride4_levels is None:
            downsample_time_stride4_levels = []
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
                % str(self.num_resolutions)
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                if i_level in self.downsample_time_stride4_levels:
                    down.downsample = DownsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Encode ``x`` and return the output of the final convolution."""
        # The encoder is not timestep-conditioned.
        temb = None

        # downsampling; only the latest activation is ever consumed here, so
        # no list of intermediate feature maps is kept alive.
        h = self.conv_in(x)
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](h, temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
            if i_level != self.num_resolutions - 1:
                h = self.down[i_level].downsample(h)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    """Convolutional decoder: middle blocks followed by upsampling stages.

    The latent passes through a 3x3 conv, two ResNet blocks with attention
    in between, then, from the coarsest level to the finest,
    ``num_res_blocks + 1`` ResNet blocks per level (each followed by
    attention when the current resolution is in ``attn_resolutions``) and an
    upsampling layer on every level except level 0.

    Args:
        ch: Base channel count; level ``i`` uses ``ch * ch_mult[i]``.
        out_ch: Channels of the decoded output.
        ch_mult: Per-level channel multipliers.
        num_res_blocks: Each level holds ``num_res_blocks + 1`` ResNet blocks.
        attn_resolutions: Resolutions at which attention is inserted.
        dropout: Dropout probability for the ResNet blocks.
        resamp_with_conv: Forwarded to the upsampling layers.
        in_channels: Stored on the instance; not otherwise used here.
        resolution: Nominal output resolution, used for ``z_shape`` and to
            decide where attention goes.
        z_channels: Channels of the latent input.
        give_pre_end: If true, ``forward`` returns the features before the
            final norm / conv.
        tanh_out: If true, ``forward`` applies ``tanh`` to the output.
        use_linear_attn: If true, forces ``attn_type="linear"``.
        downsample_time_stride4_levels: Level ``i`` upsamples with
            ``UpsampleTimeStride4`` when ``i - 1`` is listed here. ``None``
            (the default) means no such level.
        attn_type: Attention flavour passed to ``make_attn``.
        **ignorekwargs: Extra config entries, ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        downsample_time_stride4_levels=None,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        # A ``[]`` default would be one list object shared by every instance
        # that relies on the default; create a fresh one per instance.
        if downsample_time_stride4_levels is None:
            downsample_time_stride4_levels = []
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out
        self.downsample_time_stride4_levels = downsample_time_stride4_levels

        if len(self.downsample_time_stride4_levels) > 0:
            assert max(self.downsample_time_stride4_levels) < self.num_resolutions, (
                "The level to perform downsample 4 operation need to be smaller than the total resolution number %s"
                % str(self.num_resolutions)
            )

        # compute block_in and curr_res at lowest res
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                # Mirror the encoder: the upsample leaving level ``i`` undoes
                # the downsample that was configured for level ``i - 1``.
                if i_level - 1 in self.downsample_time_stride4_levels:
                    up.upsample = UpsampleTimeStride4(block_in, resamp_with_conv)
                else:
                    up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        """Decode latent ``z``; also records ``z.shape`` in ``last_z_shape``."""
        self.last_z_shape = z.shape

        # The decoder is not timestep-conditioned.
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = nonlinearity(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h
class UpsampleDecoder(nn.Module):
    """Stack of ResNet stages with a 2x ``Upsample`` between consecutive stages.

    Every stage holds ``num_res_blocks + 1`` ResNet blocks of width
    ``ch * ch_mult[level]``; a final norm / swish / 3x3 conv maps to
    ``out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        ch,
        num_res_blocks,
        resolution,
        ch_mult=(2, 2),
        dropout=0.0,
    ):
        super().__init__()
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks

        last_level = self.num_resolutions - 1
        # Resolution bookkeeping only; nothing below reads this value.
        curr_res = resolution // 2**last_level
        width = in_channels
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for level in range(self.num_resolutions):
            stage_width = ch * ch_mult[level]
            stage = []
            for _ in range(self.num_res_blocks + 1):
                stage.append(
                    ResnetBlock(
                        in_channels=width,
                        out_channels=stage_width,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                width = stage_width
            self.res_blocks.append(nn.ModuleList(stage))
            if level != last_level:
                self.upsample_blocks.append(Upsample(width, True))
                curr_res = curr_res * 2

        # output head
        self.norm_out = Normalize(width)
        self.conv_out = nn.Conv2d(
            width, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        out = x
        last_level = self.num_resolutions - 1
        for level in range(self.num_resolutions):
            stage = self.res_blocks[level]
            for idx in range(self.num_res_blocks + 1):
                out = stage[idx](out, None)
            if level != last_level:
                out = self.upsample_blocks[level](out)
        out = self.norm_out(out)
        out = nonlinearity(out)
        return self.conv_out(out)
class MergedRescaleEncoder(nn.Module):
    """An ``Encoder`` followed by a ``LatentRescaler``.

    The encoder emits ``ch * ch_mult[-1]`` channels (``double_z`` disabled);
    the rescaler resizes that feature map by ``rescale_factor`` and projects
    it to ``out_ch`` channels.
    """

    def __init__(
        self,
        in_channels,
        ch,
        resolution,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        ch_mult=(1, 2, 4, 8),
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        bottleneck_channels = ch * ch_mult[-1]

        encoder_kwargs = dict(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=bottleneck_channels,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        rescaler_kwargs = dict(
            factor=rescale_factor,
            in_channels=bottleneck_channels,
            mid_channels=bottleneck_channels,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )
        self.encoder = Encoder(**encoder_kwargs)
        self.rescaler = LatentRescaler(**rescaler_kwargs)

    def forward(self, x):
        return self.rescaler(self.encoder(x))
class Resize(nn.Module):
    """Rescale a feature map with ``torch.nn.functional.interpolate``.

    Args:
        in_channels: Only relevant to the learned variant, which is not
            implemented.
        learned: Requesting learned resampling raises
            ``NotImplementedError``.
        mode: Interpolation mode used by ``forward``.
    """

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # ``__name__`` (not ``__name``): inside a class body the latter is
            # name-mangled and raised AttributeError before the intended
            # NotImplementedError could be reached.
            print(
                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
            )
            raise NotImplementedError()

    def forward(self, x, scale_factor=1.0):
        """Return ``x`` unchanged for a factor of 1.0, else interpolate it."""
        if scale_factor == 1.0:
            return x
        return torch.nn.functional.interpolate(
            x, mode=self.mode, align_corners=False, scale_factor=scale_factor
        )


class FirstStagePostProcessor(nn.Module):
    """Encode inputs with a frozen first-stage model and post-process them.

    The pretrained model's encoding is normalized, projected to
    ``n_channels`` and passed through one ResNet block plus one 2x average
    pool downsample per entry of ``ch_mult``.

    Args:
        ch_mult: Channel multipliers, one ResNet/downsample pair each.
        in_channels: Channels of the pretrained model's encoding.
        pretrained_model: Ready-made first-stage model; required when
            ``pretrained_config`` is ``None``.
        reshape: If true, ``forward`` flattens the output to
            ``(b, h * w, c)``.
        n_channels: Projection width; defaults to
            ``pretrained_model.encoder.ch``.
        dropout: Dropout probability for the ResNet blocks.
        pretrained_config: Config from which the first-stage model is
            instantiated; takes precedence over ``pretrained_model``.
    """

    def __init__(
        self,
        ch_mult: list,
        in_channels,
        pretrained_model: nn.Module = None,
        reshape=False,
        n_channels=None,
        dropout=0.0,
        pretrained_config=None,
    ):
        super().__init__()
        if pretrained_config is None:
            assert (
                pretrained_model is not None
            ), 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(
            in_channels, n_channels, kernel_size=3, stride=1, padding=1
        )

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(
                ResnetBlock(
                    in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
                )
            )
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)

    def instantiate_pretrained(self, config):
        """Build the first-stage model from ``config`` and freeze it."""
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        """Encode ``x``; a Gaussian posterior is reduced to its mode."""
        # ``DiagonalGaussianDistribution`` was referenced without ever being
        # imported, so every call raised NameError. Imported lazily to avoid
        # touching the module's import block.
        # NOTE(review): assumes the class lives in the sibling
        # ``distributions`` module of this package - confirm the path.
        from audioldm.variational_autoencoder.distributions import (
            DiagonalGaussianDistribution,
        )

        c = self.pretrained_model.encode(x)
        if isinstance(c, DiagonalGaussianDistribution):
            c = c.mode()
        return c

    def forward(self, x):
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = nonlinearity(z)

        for submodel, downmodel in zip(self.model, self.downsampler):
            z = submodel(z, temb=None)
            z = downmodel(z)

        if self.do_reshape:
            z = rearrange(z, "b c h w -> b (h w) c")
        return z
"down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 8, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 8, + "sample_size": [32, 2], + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "use_linear_projection": true, + "upcast_attention": true +} diff --git a/configs/diffusion_model_config_large_2048.json b/configs/diffusion_model_config_large_2048.json new file mode 100644 index 0000000000000000000000000000000000000000..1b3c3aa82591555a9a4231ab23565c19986be80f --- /dev/null +++ b/configs/diffusion_model_config_large_2048.json @@ -0,0 +1,46 @@ +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.10.0.dev0", + "act_fn": "silu", + "attention_head_dim": [ + 5, + 10, + 20, + 20 + ], + "block_out_channels": [ + 320, + 640, + 640, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 2048, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 8, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 8, + "sample_size": [32, 2], + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "use_linear_projection": true, + "upcast_attention": true +} diff --git a/configs/diffusion_model_config_pretrain.json b/configs/diffusion_model_config_pretrain.json new file mode 100644 index 
0000000000000000000000000000000000000000..c97fd7fb4588f44008c7fc65ec4d7a84f79ab66f --- /dev/null +++ b/configs/diffusion_model_config_pretrain.json @@ -0,0 +1,61 @@ +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.20.0.dev0", + "act_fn": "silu", + "addition_embed_type": null, + "addition_embed_type_num_heads": 64, + "addition_time_embed_dim": null, + "attention_head_dim": 8, + "block_out_channels": [ + 128, + 256, + 384, + 640 + ], + "center_input_sample": false, + "conv_in_kernel": 3, + "conv_out_kernel": 3, + "cross_attention_dim": 1024, + "cross_attention_norm": null, + "down_block_types": [ + "DownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "encoder_hid_dim": null, + "encoder_hid_dim_type": null, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 8, + "layers_per_block": 2, + "mid_block_only_cross_attention": null, + "mid_block_scale_factor": 1, + "mid_block_type": "UNetMidBlock2DCrossAttn", + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_attention_heads": null, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 8, + "resnet_out_scale_factor": 1.0, + "resnet_skip_time_act": false, + "resnet_time_scale_shift": "default", + "sample_size": 128, + "time_cond_proj_dim": null, + "time_embedding_act_fn": null, + "time_embedding_dim": null, + "time_embedding_type": "positional", + "timestep_post_act": null, + "transformer_layers_per_block": 1, + "up_block_types": [ + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "UpBlock2D" + ], + "upcast_attention": false, + "use_linear_projection": false +} diff --git a/configs/stable-diffusion-2-1.scheduler_48k.json b/configs/stable-diffusion-2-1.scheduler_48k.json new file mode 100644 index 0000000000000000000000000000000000000000..829265aba074379b79b6795bca25f450e5949b8c --- /dev/null +++ 
b/configs/stable-diffusion-2-1.scheduler_48k.json @@ -0,0 +1,14 @@ +{ + "_class_name": "DDPMScheduler", + "_diffusers_version": "0.8.0", + "beta_end": 0.02, + "beta_schedule": "scaled_linear", + "beta_start": 0.0015, + "clip_sample": false, + "num_train_timesteps": 1000, + "prediction_type": "v_prediction", + "set_alpha_to_one": false, + "skip_prk_steps": true, + "steps_offset": 1, + "trained_betas": null +} diff --git a/configs/stable_diffusion_2.1.json b/configs/stable_diffusion_2.1.json new file mode 100644 index 0000000000000000000000000000000000000000..9b1458658e8651398962171a8c5c56c5c0bd5aea --- /dev/null +++ b/configs/stable_diffusion_2.1.json @@ -0,0 +1,46 @@ +{ + "_class_name": "UNet2DConditionModel", + "_diffusers_version": "0.10.0.dev0", + "act_fn": "silu", + "attention_head_dim": [ + 5, + 10, + 20, + 20 + ], + "block_out_channels": [ + 320, + 640, + 1280, + 1280 + ], + "center_input_sample": false, + "cross_attention_dim": 1024, + "down_block_types": [ + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 4, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": null, + "only_cross_attention": false, + "out_channels": 4, + "sample_size": 96, + "up_block_types": [ + "UpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D", + "CrossAttnUpBlock2D" + ], + "use_linear_projection": true, + "upcast_attention": true +} diff --git a/configs/stable_diffusion_sdxl_scheduler_config.json b/configs/stable_diffusion_sdxl_scheduler_config.json new file mode 100644 index 0000000000000000000000000000000000000000..e5bc8421e047838523be7acfb6720f167f7382f6 --- /dev/null +++ b/configs/stable_diffusion_sdxl_scheduler_config.json @@ -0,0 +1,18 @@ +{ + "_class_name": "EulerDiscreteScheduler", + "_diffusers_version": "0.19.0.dev0", + 
"beta_end": 0.012, + "beta_schedule": "scaled_linear", + "beta_start": 0.00085, + "clip_sample": false, + "interpolation_type": "linear", + "num_train_timesteps": 1000, + "prediction_type": "epsilon", + "sample_max_value": 1.0, + "set_alpha_to_one": false, + "skip_prk_steps": true, + "steps_offset": 1, + "timestep_spacing": "leading", + "trained_betas": null, + "use_karras_sigmas": false +} diff --git a/data/-26aVYRtEAc_000030.mp4 b/data/-26aVYRtEAc_000030.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..1eb4d84fd3980f36ae1903cb970d8935bad678cc --- /dev/null +++ b/data/-26aVYRtEAc_000030.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5d14a1a521462ccba17f07c8abe5069d91e10033dd701f17b0a9f46c49039703 +size 1037905 diff --git a/data/-BAKe6QGTUk_000030.mp4 b/data/-BAKe6QGTUk_000030.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..c0f4b5b2e27d52ee31685e322aa0cabcbe89450c Binary files /dev/null and b/data/-BAKe6QGTUk_000030.mp4 differ diff --git a/data/-yoaSondvkw_000071.mp4 b/data/-yoaSondvkw_000071.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..321f50e97456eb3d3df0d9a81217339435a1f01f --- /dev/null +++ b/data/-yoaSondvkw_000071.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d27b48c1b1c0c6ac5ad04c73f1a035d58ac7fd94e47f90e3b55539aecf2560d5 +size 1989725 diff --git a/data/0Bp8c3PfAAA_000053.mp4 b/data/0Bp8c3PfAAA_000053.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..0ab685cdbb465fdf0c810ef92a2b72126bba8ae6 --- /dev/null +++ b/data/0Bp8c3PfAAA_000053.mp4 @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9cd18cf3c1cff720d7cdc4fec588d0f75294481cf6a69dd8d8b8f3d9edfd0c84 +size 2485403 diff --git a/data/0DCit2EBtjs_000030.mp4 b/data/0DCit2EBtjs_000030.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..9b86f756b4017551ab41b896ef86445edd344336 Binary files /dev/null and 
b/data/0DCit2EBtjs_000030.mp4 differ diff --git a/inference_from_video.py b/inference_from_video.py new file mode 100644 index 0000000000000000000000000000000000000000..96cbbbd7c020ea6580aa894d5c8a344af7ba30e5 --- /dev/null +++ b/inference_from_video.py @@ -0,0 +1,221 @@ +import os +import copy +import json +import time +import torch +import argparse +from PIL import Image +import numpy as np +import soundfile as sf +#import wandb +from tqdm import tqdm +from diffusers import DDPMScheduler +from models import build_pretrained_models, AudioDiffusion +from transformers import AutoProcessor, ClapModel +import torchaudio +import tools.torch_tools as torch_tools +from datasets import load_dataset + +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + +def chunks(lst, n): + """Yield successive n-sized chunks from lst.""" + for i in range(0, len(lst), n): + yield lst[i:i + n] + +def parse_args(): + parser = argparse.ArgumentParser(description="Inference for text to audio generation task.") + parser.add_argument( + "--original_args", type=str, default=None, + help="Path for summary jsonl file saved during training." + ) + parser.add_argument( + "--model", type=str, default=None, + help="Path for saved model bin file." + ) + parser.add_argument( + "--vae_model", type=str, default="audioldm-s-full", + help="Path for saved model bin file." + ) + parser.add_argument( + "--num_steps", type=int, default=200, + help="How many denoising steps for generation.", + ) + parser.add_argument( + "--guidance", type=float, default=3, + help="Guidance scale for classifier free guidance." 
+ ) + parser.add_argument( + "--batch_size", type=int, default=1, + help="Batch size for generation.", + ) + parser.add_argument( + "--num_samples", type=int, default=1, + help="How many samples per prompt.", + ) + parser.add_argument( + "--num_test_instances", type=int, default=-1, + help="How many test instances to evaluate.", + ) + parser.add_argument( + "--sample_rate", type=int, default=-1, + help="How many test instances to evaluate.", + ) + parser.add_argument( + "--save_dir", type=str, default="./outputs/tmp", + help="output save dir" + ) + parser.add_argument( + "--data_path", type=str, default="data/video_processed/video_gt_augment", + help="inference data path" + ) + + args = parser.parse_args() + + return args + +def main(): + args = parse_args() + + train_args = dotdict(json.loads(open(args.original_args).readlines()[0])) + if "hf_model" not in train_args: + train_args["hf_model"] = None + + # Load Models # + name = train_args.vae_model + vae, stft = build_pretrained_models(name) + vae, stft = vae.cuda(), stft.cuda() + model_class = AudioDiffusion + if train_args.ib: + print("*****USING MODEL IMAGEBIND*****") + from models_imagebind import AudioDiffusion_IB + model_class = AudioDiffusion if not train_args.ib else AudioDiffusion_IB + elif train_args.lb: + print("*****USING MODEL LANGUAGEBIND*****") + from models_languagebind import AudioDiffusion_LB + model_class = AudioDiffusion_LB + elif train_args.jepa: + print("*****USING MODEL JEPA*****") + from models_vjepa import AudioDiffusion_JEPA + model_class = AudioDiffusion_JEPA + + model = model_class( + train_args.fea_encoder_name, + train_args.scheduler_name, + train_args.unet_model_name, + train_args.unet_model_config, + train_args.snr_gamma, + train_args.freeze_text_encoder, + train_args.uncondition, + train_args.img_pretrained_model_path, + train_args.task, + train_args.embedding_dim, + train_args.pe + ) + + model.eval() + + # Load Trained Weight # + device = torch.device("cuda:0") #vae.device() + if 
args.model.endswith(".pt") or args.model.endswith(".bin"): + model.load_state_dict(torch.load(args.model), strict=False) + else: + from safetensors.torch import load_model + load_model(model, args.model, strict=False) + + model.to(device) + + scheduler = DDPMScheduler.from_pretrained(train_args.scheduler_name, subfolder="scheduler") + sample_rate = args.sample_rate + #evaluator = EvaluationHelper(16000, "cuda:0") + + + def audio_text_matching(waveforms, text, sample_freq=24000, max_len_in_seconds=10): + new_freq = 48000 + resampled = [] + + for wav in waveforms: + x = torchaudio.functional.resample(torch.tensor(wav, dtype=torch.float).reshape(1, -1), orig_freq=sample_freq, new_freq=new_freq)[0].numpy() + resampled.append(x[:new_freq*max_len_in_seconds]) + + inputs = clap_processor(text=text, audios=resampled, return_tensors="pt", padding=True, sampling_rate=48000) + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = clap(**inputs) + + logits_per_audio = outputs.logits_per_audio + ranks = torch.argsort(logits_per_audio.flatten(), descending=True).cpu().numpy() + return ranks + + # Load Data # + if train_args.prefix: + prefix = train_args.prefix + else: + prefix = "" + + # data_path = "data/video_test/" + data_path = args.data_path + wavname = [f"{name.split('.')[0]}.wav" for name in os.listdir(data_path)] + video_features = [] + for video_file in os.listdir(data_path): + video_path = os.path.join(data_path, video_file) + video_feature = torch_tools.load_video(video_path, frame_rate=2, size=224) + print(video_feature.shape) + video_features.append(video_feature) + + # Generate # + num_steps, guidance, batch_size, num_samples = args.num_steps, args.guidance, args.batch_size, args.num_samples + all_outputs = [] + + for k in tqdm(range(0, len(wavname), batch_size)): + + with torch.no_grad(): + # if train_args.task == 'image2audio': + # prompt = text_prompts[k: k+batch_size] + # imgs = [] + # for img_path in prompt: + # img = 
Image.open(img_path) + # imgs.append(np.array(img)) + # prompt = imgs + # elif train_args.task == 'video2audio': + prompt = video_features[k: k+batch_size] + + latents = model.inference(scheduler, None, prompt, None, num_steps, guidance, num_samples, disable_progress=True, device=device) + mel = vae.decode_first_stage(latents) + wave = vae.decode_to_waveform(mel) + + all_outputs += [item for item in wave] + + # Save # + exp_id = str(int(time.time())) + if not os.path.exists("outputs"): + os.makedirs("outputs") + + if num_samples == 1: + output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}_augment".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate) + os.makedirs(output_dir, exist_ok=True) + for j, wav in enumerate(all_outputs): + sf.write("{}/{}".format(output_dir, wavname[j]), wav, samplerate=sample_rate) + + else: + for i in range(num_samples): + output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}/rank_{}".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate, i+1) + os.makedirs(output_dir, exist_ok=True) + + groups = list(chunks(all_outputs, num_samples)) + for k in tqdm(range(len(groups))): + wavs_for_text = groups[k] + rank = audio_text_matching(wavs_for_text, text_prompts[k]) + ranked_wavs_for_text = [wavs_for_text[r] for r in rank] + + for i, wav in enumerate(ranked_wavs_for_text): + output_dir = "{}/{}_{}_steps_{}_guidance_{}_sampleRate_{}/rank_{}".format(args.save_dir, exp_id, "_".join(args.model.split("/")[1:-1]), num_steps, guidance, sample_rate, i+1) + sf.write("{}/{}".format(output_dir, wavname[k]), wav, samplerate=sample_rate) + +if __name__ == "__main__": + main() diff --git a/inference_from_video.sh b/inference_from_video.sh new file mode 100644 index 0000000000000000000000000000000000000000..027cd593387d7ab62f93881870503cbf12dbc44e --- /dev/null +++ b/inference_from_video.sh @@ -0,0 +1,21 @@ +steps=300 +guidance=3 +num_samples=1 
+model="vta-ldm-clip4clip-v-large" +# model="vta_ldm_clip4clip_augment_pe" +# model="vta_ldm_clip4clip_augment_ib" +# model="vta_ldm_clip4clip_ib" +# model="vta_ldm_clip4clip_lb" +# model="vta_ldm_clip4clip_pe" +# model="vta_ldm_clip4clip_text" +# model="vta_ldm_youtube" +# model="vta_ldm_vjepa" +CUDA_VISIBLE_DEVICES=2 python3.10 inference_from_video.py --original_args="ckpt/$model/summary.jsonl" \ +--model="ckpt/$model/pytorch_model_2.bin" \ +--sample_rate 16000 \ +--data_path "data" \ +--save_dir outputs/$model \ +--num_steps $steps \ +--guidance $guidance \ +--num_samples $num_samples \ +--batch_size 8 diff --git a/models.py b/models.py new file mode 100644 index 0000000000000000000000000000000000000000..d669dc254ef433825092fe7d1bcdc2d148604456 --- /dev/null +++ b/models.py @@ -0,0 +1,616 @@ +import random +import numpy as np +from tqdm import tqdm +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +from einops import repeat +import time +from tools.torch_tools import wav_to_fbank, sinusoidal_positional_embedding + +from audioldm.audio.stft import TacotronSTFT +from audioldm.variational_autoencoder import AutoencoderKL +from audioldm.utils import default_audioldm_config, get_metadata + +from transformers import CLIPTokenizer, AutoTokenizer, T5Tokenizer +from transformers import CLIPTextModel, T5EncoderModel, AutoModel +from transformers import CLIPVisionModelWithProjection, CLIPTextModelWithProjection +from transformers import CLIPProcessor, CLIPModel + +import diffusers +from diffusers.utils.torch_utils import randn_tensor +from diffusers import DDPMScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL as DiffuserAutoencoderKL +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode, RandomResizedCrop +from diffusers import AudioLDMPipeline + +def build_pretrained_models(name): + checkpoint = torch.load(name, map_location="cpu") + scale_factor = 
checkpoint["state_dict"]["scale_factor"].item() + + vae_state_dict = {k[18:]: v for k, v in checkpoint["state_dict"].items() if "first_stage_model." in k} + + config = default_audioldm_config(name) + vae_config = config["model"]["params"]["first_stage_config"]["params"] + vae_config["scale_factor"] = scale_factor + + vae = AutoencoderKL(**vae_config) + vae.load_state_dict(vae_state_dict) + + fn_STFT = TacotronSTFT( + config["preprocessing"]["stft"]["filter_length"], + config["preprocessing"]["stft"]["hop_length"], + config["preprocessing"]["stft"]["win_length"], + config["preprocessing"]["mel"]["n_mel_channels"], + config["preprocessing"]["audio"]["sampling_rate"], + config["preprocessing"]["mel"]["mel_fmin"], + config["preprocessing"]["mel"]["mel_fmax"], + ) + + vae.eval() + fn_STFT.eval() + return vae, fn_STFT + + +class EffNetb3(nn.Module): + def __init__(self, pretrained_model_path, embedding_dim=1024, pretrained=True): + super(EffNetb3, self).__init__() + self.model_name = 'effnetb3' + self.pretrained = pretrained + # Create model + # self.effnet = torch.hub.load('rwightman/gen-efficientnet-pytorch', 'efficientnet_b3', pretrained=self.pretrained) + # torch.save(self.effnet, 'model.pth') + self.effnet = torch.hub.load(pretrained_model_path, 'efficientnet_b3', trust_repo=True, source='local') + #self.effnet.conv_stem = nn.Conv2d(1, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False) + self.embedder = nn.Conv2d(384, embedding_dim, kernel_size=1, stride=1, padding=0) + + def forward(self, x): + #out = self.effnet(x) + out = self.effnet.conv_stem(x) + out = self.effnet.bn1(out) + out = self.effnet.act1(out) + for i in range(len(self.effnet.blocks)): + out = self.effnet.blocks[i](out) + out = self.embedder(out) + return out + + +class EffNetb3_last_layer(nn.Module): + def __init__(self, pretrained_model_path, embedding_dim=1024, pretrained=True): + super(EffNetb3_last_layer, self).__init__() + self.model_name = 'effnetb3' + self.pretrained = 
pretrained + self.effnet = torch.hub.load(pretrained_model_path, 'efficientnet_b3', trust_repo=True, source='local') + self.effnet.classifier = nn.Linear(1536, embedding_dim) + + def forward(self, x): + out = self.effnet(x) + return out.unsqueeze(-1) + + +class Clip4Video(nn.Module): + def __init__(self, model, embedding_dim=1024, pretrained=True, pe=False): + super(Clip4Video, self).__init__() + self.pretrained = pretrained + self.clip_vision = CLIPVisionModelWithProjection.from_pretrained(model) + self.clip_text = CLIPTextModelWithProjection.from_pretrained(model) + self.tokenizer = AutoTokenizer.from_pretrained(model) + + input_dim = 512 if "clip-vit-base" in model else 768 + self.linear_layer = nn.Linear(input_dim, embedding_dim) + self.pe = sinusoidal_positional_embedding(30, input_dim) if pe else None + print("*****PE*****") if pe else print("*****W/O PE*****") + + def forward(self, text=None, image=None, video=None): + assert text is not None or image is not None or video is not None, "At least one of text, image or video should be provided" + if text is not None and video is None: + inputs = self.tokenizer([text], padding=True, truncation=True, return_tensors="pt", max_length=77).to(self.clip_text.device) + out = self.clip_text(**inputs) + out = out.text_embeds.repeat(20, 1) + elif video is not None and text is None: + out = self.clip_vision(video.to(self.clip_vision.device)) # input video x: t * 3 * w * h + out = out.image_embeds # t * 512 + if self.pe is not None: + out = out + self.pe[:out.shape[0], :].to(self.clip_vision.device) + # out['last_hidden_state'].shape # t * 50 * 768 + # out['image_embeds'].shape # t * 512 + elif text is not None and video is not None: + text_inputs = self.tokenizer([text], padding=True, truncation=True, return_tensors="pt", max_length=77).to(self.clip_text.device) + video_out = self.clip_vision(video.to(self.clip_vision.device)) + video_out = video_out.image_embeds + text_out = self.clip_text(**text_inputs) + text_out = 
text_out.text_embeds.repeat(video_out.shape[0], 1) + # out = text_out + video_out + # concat + out = torch.cat([text_out, video_out], dim=0) + out = self.linear_layer(out) # t * 1024 + return out + + +class AudioDiffusion(nn.Module): + def __init__( + self, + fea_encoder_name, + scheduler_name, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + freeze_text_encoder=True, + uncondition=False, + img_pretrained_model_path=None, + task=None, + embedding_dim=1024, + pe=False + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.fea_encoder_name = fea_encoder_name + self.scheduler_name = scheduler_name + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.freeze_text_encoder = freeze_text_encoder + self.uncondition = uncondition + self.task = task + self.pe = pe + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.noise_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler") + self.inference_scheduler = DDPMScheduler.from_pretrained(self.scheduler_name, subfolder="scheduler") + + if unet_model_config_path: + unet_config = UNet2DConditionModel.load_config(unet_model_config_path) + print("unet_config", unet_config) + self.unet = UNet2DConditionModel.from_config(unet_config, subfolder="unet") + self.set_from = "random" + print("UNet initialized randomly.") + else: + self.unet = UNet2DConditionModel.from_pretrained(unet_model_name, subfolder="unet") + self.set_from = "pre-trained" + self.group_in = nn.Sequential(nn.Linear(8, 512), nn.Linear(512, 4)) + self.group_out = nn.Sequential(nn.Linear(4, 512), nn.Linear(512, 8)) + print("UNet initialized from stable diffusion checkpoint.") + + if self.task == "text2audio": + if "stable-diffusion" in self.fea_encoder_name: + self.tokenizer = 
CLIPTokenizer.from_pretrained(self.fea_encoder_name, subfolder="tokenizer") + self.text_encoder = CLIPTextModel.from_pretrained(self.fea_encoder_name, subfolder="text_encoder") + elif "t5" in self.fea_encoder_name and "Chinese" not in self.fea_encoder_name: + self.tokenizer = AutoTokenizer.from_pretrained(self.fea_encoder_name) + self.text_encoder = T5EncoderModel.from_pretrained(self.fea_encoder_name) + elif "Chinese" in self.fea_encoder_name: + self.tokenizer = T5Tokenizer.from_pretrained(self.fea_encoder_name) + self.text_encoder = T5EncoderModel.from_pretrained(self.fea_encoder_name) + elif "clap" in self.fea_encoder_name: + self.tokenizer = RobertaTokenizer.from_pretrained("roberta-base") + self.CLAP_model = laion_clap.CLAP_Module(enable_fusion=False) + self.CLAP_model.load_ckpt(self.fea_encoder_name) + elif "clip-vit" in self.fea_encoder_name: + # self.CLIP_model = CLIPModel.from_pretrained(self.fea_encoder_name) + # self.CLIP_processor = CLIPProcessor.from_pretrained(self.fea_encoder_name) + self.CLIP_model = CLIPTextModelWithProjection.from_pretrained(self.fea_encoder_name) + self.tokenizer = AutoTokenizer.from_pretrained(self.fea_encoder_name) + if "base" in self.fea_encoder_name: + self.linear_layer = nn.Linear(512, embedding_dim) + else: + self.linear_layer = nn.Linear(768, embedding_dim) + else: + self.tokenizer = AutoTokenizer.from_pretrained(self.fea_encoder_name) + self.text_encoder = AutoModel.from_pretrained(self.fea_encoder_name) + elif self.task == "image2audio": + if "clip-vit" in self.fea_encoder_name: + self.CLIP_model = CLIPModel.from_pretrained(self.fea_encoder_name) + self.CLIP_processor = CLIPProcessor.from_pretrained(self.fea_encoder_name) + self.linear_layer = nn.Linear(512, embedding_dim) + # self.img_fea_extractor = EffNetb3(img_pretrained_model_path) + else: + self.img_fea_extractor = EffNetb3_last_layer(img_pretrained_model_path) + elif self.task == "video2audio": + self.vid_fea_extractor = Clip4Video(model=self.fea_encoder_name, 
embedding_dim=embedding_dim, pe=pe) + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
+ snr = (alpha / sigma) ** 2 + return snr + + def encode_text(self, prompt): + device = self.text_encoder.device + batch = self.tokenizer( + prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt" + ) + input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) + + if self.freeze_text_encoder: + with torch.no_grad(): + encoder_hidden_states = self.text_encoder( + input_ids=input_ids, attention_mask=attention_mask + )[0] + else: + encoder_hidden_states = self.text_encoder( + input_ids=input_ids, attention_mask=attention_mask + )[0] + + boolean_encoder_mask = (attention_mask == 1).to(device) + return encoder_hidden_states, boolean_encoder_mask + + def encode_text_CLAP(self, prompt): + device = self.text_encoder.device + batch = self.tokenizer(prompt, padding="max_length", truncation=True, max_length=self.tokenizer.model_max_length, return_tensors="pt") + input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) + + if self.freeze_text_encoder: + with torch.no_grad(): + encoder_hidden_states = self.CLAP_model.model.get_text_embedding(prompt) + else: + encoder_hidden_states = self.CLAP_model.model.get_text_embedding(prompt) + + boolean_encoder_mask = (attention_mask == 1).to(device) + return encoder_hidden_states, boolean_encoder_mask + + def encode_image(self, prompt, device): + if "clip-vit" in self.fea_encoder_name: + with torch.no_grad(): + inputs = self.CLIP_processor(text=["aaa"], images=prompt, return_tensors="pt", padding=True).to(device) + encoder_hidden_states = self.CLIP_model(**inputs).image_embeds + encoder_hidden_states = self.linear_layer(encoder_hidden_states) # b * 1024 + encoder_hidden_states = encoder_hidden_states.unsqueeze(1).to(device) + else: + img_fea = self.img_fea_extractor(prompt) + encoder_hidden_states = img_fea.view(img_fea.shape[0], img_fea.shape[1], -1).permute(0, 2, 1) + boolean_encoder_mask = 
torch.ones((encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]), dtype=torch.bool) + boolean_encoder_mask = boolean_encoder_mask.to(device) + + return encoder_hidden_states, boolean_encoder_mask + + def encode_video(self, video_batch, text=None, device=None): + vid_feas = [] + for i, video in enumerate(video_batch): + if text: + vid_fea = self.vid_fea_extractor(video=video, text=text[i]) # t * fea_dim + else: + vid_fea = self.vid_fea_extractor(video=video) + vid_feas.append(vid_fea) + + padding = 0 + size = max(v.size(0) for v in vid_feas) + batch_size = len(vid_feas) + embed_size = vid_feas[0].size(1) + encoder_hidden_states = vid_feas[0].new(batch_size, size, embed_size).fill_(padding) + boolean_encoder_mask = torch.ones((batch_size, size), dtype=torch.bool) + + def copy_tensor(src, dst): + assert dst.numel() == src.numel() + dst.copy_(src) + + for i, v in enumerate(vid_feas): + copy_tensor(v, encoder_hidden_states[i][: len(v)]) + boolean_encoder_mask[i, len(v):] = False + return encoder_hidden_states.to(device), boolean_encoder_mask.to(device) + + def encode_text_CLIP(self, prompt, device): + # tmp_image = np.ones((512, 512, 3)) + # with torch.no_grad(): + # inputs = self.CLIP_processor(text=prompt, images=tmp_image, return_tensors="pt", padding=True, max_length=77, truncation=True).to(device) + # encoder_hidden_states = self.CLIP_model(**inputs).text_embeds # b * 768 + text_inputs = self.tokenizer(prompt, padding=True, truncation=True, return_tensors="pt", max_length=77).to(device) + encoder_hidden_states = self.CLIP_model(**text_inputs).text_embeds + encoder_hidden_states = self.linear_layer(encoder_hidden_states) # b * 1024 + encoder_hidden_states = encoder_hidden_states.unsqueeze(1).to(device) + boolean_encoder_mask = torch.ones((encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]), dtype=torch.bool) + boolean_encoder_mask = boolean_encoder_mask.to(device) + + return encoder_hidden_states, boolean_encoder_mask + + def forward(self, 
latents, text=None, video=None, image=None, validation_mode=False, device=None): + num_train_timesteps = self.noise_scheduler.num_train_timesteps + self.noise_scheduler.set_timesteps(num_train_timesteps, device=device) + # encoder_hidden_states.shape [b, t, f] + if self.task == "text2audio": + if "clip-vit" in self.fea_encoder_name: + encoder_hidden_states, boolean_encoder_mask = self.encode_text_CLIP(text, device) + else: + encoder_hidden_states, boolean_encoder_mask = self.encode_text(text) + if self.uncondition: + mask_indices = [k for k in range(len(text)) if random.random() < 0.1] + # mask_indices = [k for k in range(len(prompt))] + if len(mask_indices) > 0: + encoder_hidden_states[mask_indices] = 0 + elif self.task == "image2audio": + encoder_hidden_states, boolean_encoder_mask = self.encode_image(image, device=device) + elif self.task == "video2audio": + encoder_hidden_states, boolean_encoder_mask = self.encode_video(video, text, device=device) + + bsz = latents.shape[0] + if validation_mode: + timesteps = (self.noise_scheduler.num_train_timesteps//2) * torch.ones((bsz,), dtype=torch.int64, device=device) + else: + # Sample a random timestep for each instance + timesteps = torch.randint(0, self.noise_scheduler.num_train_timesteps, (bsz,), device=device) + timesteps = timesteps.long() + + noise = torch.randn_like(latents) + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + + # Get the target for loss depending on the prediction type + if self.noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif self.noise_scheduler.config.prediction_type == "v_prediction": + target = self.noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type {self.noise_scheduler.config.prediction_type}") + + if self.set_from == "random": + model_pred = self.unet( + noisy_latents, timesteps, encoder_hidden_states, + encoder_attention_mask=boolean_encoder_mask + ).sample + + elif 
self.set_from == "pre-trained": + compressed_latents = self.group_in(noisy_latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous() + model_pred = self.unet( + compressed_latents, timesteps, encoder_hidden_states, + encoder_attention_mask=boolean_encoder_mask + ).sample + model_pred = self.group_out(model_pred.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous() + + if self.snr_gamma is None: + loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean") + else: + # Compute loss-weights as per Section 3.4 of https://arxiv.org/abs/2303.09556. + # Adaptef from huggingface/diffusers/blob/main/examples/text_to_image/train_text_to_image.py + snr = self.compute_snr(timesteps) + mse_loss_weights = ( + torch.stack([snr, self.snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + ) + loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") + loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights + loss = loss.mean() + + return loss + + @torch.no_grad() + def inference(self, inference_scheduler, text=None, video=None, image=None, num_steps=20, guidance_scale=3, num_samples_per_prompt=1, + disable_progress=True, device=None): + start = time.time() + classifier_free_guidance = guidance_scale > 1.0 + + #print("ldm time 0", time.time()-start, prompt) + if self.task == "text2audio": + batch_size = len(text) * num_samples_per_prompt + + if classifier_free_guidance: + if "clip-vit" in self.fea_encoder_name: + encoder_hidden_states, boolean_encoder_mask = self.encode_text_clip_classifier_free(text, num_samples_per_prompt, device=device) + else: + encoder_hidden_states, boolean_encoder_mask = self.encode_text_classifier_free(text, num_samples_per_prompt) + else: + encoder_hidden_states, boolean_encoder_mask = self.encode_text(text) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_samples_per_prompt, 0) + boolean_encoder_mask = 
boolean_encoder_mask.repeat_interleave(num_samples_per_prompt, 0) + elif self.task == "image2audio": + if classifier_free_guidance: + encoder_hidden_states, boolean_encoder_mask = self.encode_image_classifier_free(image, num_samples_per_prompt, device=device) + else: + encoder_hidden_states, boolean_encoder_mask = self.encode_image_no_grad(image, device=device) + encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_samples_per_prompt, 0) + boolean_encoder_mask = boolean_encoder_mask.repeat_interleave(num_samples_per_prompt, 0) + elif self.task == "video2audio": + batch_size = len(video) * num_samples_per_prompt + encoder_hidden_states, boolean_encoder_mask = self.encode_video_classifier_free(video, text, num_samples_per_prompt, device=device) + # import pdb;pdb.set_trace() + #print("ldm time 1", time.time()-start) + inference_scheduler.set_timesteps(num_steps, device=device) + timesteps = inference_scheduler.timesteps + + num_channels_latents = self.unet.in_channels + latents = self.prepare_latents(batch_size, inference_scheduler, num_channels_latents, encoder_hidden_states.dtype, device) + num_warmup_steps = len(timesteps) - num_steps * inference_scheduler.order + progress_bar = tqdm(range(num_steps), disable=disable_progress) + + #print("ldm time 2", time.time()-start, timesteps) + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if classifier_free_guidance else latents + latent_model_input = inference_scheduler.scale_model_input(latent_model_input, t) + + #print("ldm emu", i, time.time()-start) + noise_pred = self.unet( + latent_model_input, t, encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=boolean_encoder_mask + ).sample + + # perform guidance + if classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # 
compute the previous noisy sample x_t -> x_t-1 + latents = inference_scheduler.step(noise_pred, t, latents).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % inference_scheduler.order == 0): + progress_bar.update(1) + + #print("ldm time 3", time.time()-start) + if self.set_from == "pre-trained": + latents = self.group_out(latents.permute(0, 2, 3, 1).contiguous()).permute(0, 3, 1, 2).contiguous() + return latents + + def prepare_latents(self, batch_size, inference_scheduler, num_channels_latents, dtype, device): + shape = (batch_size, num_channels_latents, 256, 16) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * inference_scheduler.init_noise_sigma + return latents + + def encode_text_classifier_free(self, prompt, num_samples_per_prompt): + device = self.text_encoder.device + batch = self.tokenizer( + prompt, max_length=self.tokenizer.model_max_length, padding=True, truncation=True, return_tensors="pt" + ) + input_ids, attention_mask = batch.input_ids.to(device), batch.attention_mask.to(device) + + with torch.no_grad(): + prompt_embeds = self.text_encoder( + input_ids=input_ids, attention_mask=attention_mask + )[0] + + prompt_embeds = prompt_embeds.repeat_interleave(num_samples_per_prompt, 0) + attention_mask = attention_mask.repeat_interleave(num_samples_per_prompt, 0) + + # get unconditional embeddings for classifier free guidance + uncond_tokens = [""] * len(prompt) + + max_length = prompt_embeds.shape[1] + uncond_batch = self.tokenizer( + uncond_tokens, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt", + ) + uncond_input_ids = uncond_batch.input_ids.to(device) + uncond_attention_mask = uncond_batch.attention_mask.to(device) + + with torch.no_grad(): + negative_prompt_embeds = self.text_encoder( + input_ids=uncond_input_ids, 
def encode_image_no_grad(self, prompt, device):
    """Encode images into a token sequence without tracking gradients.

    The output of ``self.img_fea_extractor(prompt)`` keeps its first two
    dimensions, has every remaining dimension flattened into one, and is then
    transposed so the flattened axis becomes the sequence axis.

    Args:
        prompt: Input forwarded unchanged to ``self.img_fea_extractor``.
        device: Device the returned mask is moved to.

    Returns:
        Tuple ``(encoder_hidden_states, boolean_encoder_mask)``. The mask is
        an all-True bool tensor covering the first two dimensions of the
        hidden states.
    """
    with torch.no_grad():
        img_fea = self.img_fea_extractor(prompt)
    # NOTE(review): presumably img_fea is (batch, channels, H, W) — confirm.
    encoder_hidden_states = img_fea.view(img_fea.shape[0], img_fea.shape[1], -1).permute(0, 2, 1)
    boolean_encoder_mask = torch.ones(
        (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1]), dtype=torch.bool
    ).to(device)

    return encoder_hidden_states, boolean_encoder_mask


def encode_text_clip_classifier_free(self, prompt, num_samples_per_prompt, device):
    """Build the classifier-free-guidance batch from CLIP text features.

    Args:
        prompt: Input forwarded to ``self.encode_text_CLIP``.
        num_samples_per_prompt: Number of times each prompt is repeated
            (consecutively) along the batch axis.
        device: Device the returned tensors are moved to.

    Returns:
        Tuple ``(prompt_embeds, boolean_prompt_mask)``. The first half of the
        batch is the unconditional branch (all-zero embeddings with an
        all-True mask); the second half is the conditional branch.
    """
    # Use this path to test the effect of conditioning on input text.
    with torch.no_grad():
        encoder_hidden_states, boolean_encoder_mask = self.encode_text_CLIP(prompt, device)

    b, t, _ = encoder_hidden_states.shape
    prompt_embeds = encoder_hidden_states.repeat_interleave(num_samples_per_prompt, 0)
    attention_mask = boolean_encoder_mask.to(device).repeat_interleave(num_samples_per_prompt, 0)

    # Unconditional branch: zero embeddings, every position attended.
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)
    uncond_attention_mask = torch.ones(
        (b * num_samples_per_prompt, t), dtype=torch.bool, device=device
    )

    # Classifier-free guidance needs an unconditional and a conditional pass;
    # both are stacked into one batch so a single forward pass suffices.
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
    boolean_prompt_mask = torch.cat([uncond_attention_mask, attention_mask])

    return prompt_embeds.to(device), boolean_prompt_mask.to(device)


def encode_image_classifier_free(self, prompt, num_samples_per_prompt, device):
    """Build the classifier-free-guidance batch from image features.

    When ``self.fea_encoder_name`` contains ``"clip-vit"`` the features are
    the CLIP image embeddings projected by ``self.linear_layer``; otherwise
    they come from ``self.img_fea_extractor``.

    Args:
        prompt: Image input for the selected feature extractor.
        num_samples_per_prompt: Number of times each image is repeated
            (consecutively) along the batch axis.
        device: Device the returned tensors are moved to.

    Returns:
        Tuple ``(prompt_embeds, boolean_prompt_mask)``. The first half of the
        batch is the unconditional branch (all-zero embeddings with an
        all-True mask); the second half is the conditional branch.
    """
    with torch.no_grad():
        if "clip-vit" in self.fea_encoder_name:
            # The processor is given a placeholder text; only the image
            # embeddings of the CLIP output are used.
            inputs = self.CLIP_processor(
                text=["aaa"], images=prompt, return_tensors="pt", padding=True
            ).to(device)
            img_fea = self.CLIP_model(**inputs).image_embeds
            img_fea = self.linear_layer(img_fea)
        else:
            img_fea = self.img_fea_extractor(prompt)
        encoder_hidden_states = img_fea.view(img_fea.shape[0], img_fea.shape[1], -1).permute(0, 2, 1)

    b, t, _ = encoder_hidden_states.shape
    prompt_embeds = encoder_hidden_states.repeat_interleave(num_samples_per_prompt, 0)
    attention_mask = torch.ones(
        (b * num_samples_per_prompt, t), dtype=torch.bool, device=device
    )

    # Unconditional branch: zero embeddings, every position attended.
    negative_prompt_embeds = torch.zeros_like(prompt_embeds)
    uncond_attention_mask = torch.ones(
        (b * num_samples_per_prompt, t), dtype=torch.bool, device=device
    )

    # Classifier-free guidance needs an unconditional and a conditional pass;
    # both are stacked into one batch so a single forward pass suffices.
    prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
    boolean_prompt_mask = torch.cat([uncond_attention_mask, attention_mask])

    return prompt_embeds.to(device), boolean_prompt_mask.to(device)


def encode_video_classifier_free(self, video_batch, text_batch, num_samples_per_prompt, device):
    """Build the classifier-free-guidance batch from per-video features.

    Each video is encoded on its own with ``self.vid_fea_extractor``; the
    resulting feature sequences are zero-padded to the longest one and the
    padded positions are marked False in the mask.

    Args:
        video_batch: Iterable of video tensors; each is moved to ``device``
            before being encoded.
        text_batch: Optional per-video text, indexed like ``video_batch``.
            When falsy, the extractor is called without a ``text`` argument.
        num_samples_per_prompt: Number of times each video is repeated
            (consecutively) along the batch axis.
        device: Device the returned tensors are moved to.

    Returns:
        Tuple ``(encoder_hidden_states, boolean_encoder_mask)``. The first
        half of the batch is the unconditional branch (all-zero embeddings
        with an all-True mask); the second half is the conditional branch.

    Raises:
        ValueError: If ``video_batch`` is empty or a feature tensor is not
            2-D with the same embedding size as the first one.
    """
    use_text = bool(text_batch)
    vid_feas = []
    for i, video in enumerate(video_batch):
        video = video.to(device)
        if use_text:
            vid_feas.append(self.vid_fea_extractor(video=video, text=text_batch[i]))
        else:
            vid_feas.append(self.vid_fea_extractor(video=video))
    if not vid_feas:
        raise ValueError("video_batch must contain at least one video")

    batch_size = len(vid_feas)
    size = max(v.size(0) for v in vid_feas)
    embed_size = vid_feas[0].size(1)
    encoder_hidden_states = vid_feas[0].new_zeros(batch_size, size, embed_size)
    boolean_encoder_mask = torch.ones((batch_size, size), dtype=torch.bool)

    for i, v in enumerate(vid_feas):
        length = v.size(0)
        if v.shape != (length, embed_size):
            raise ValueError(
                "video feature {} has shape {}, expected ({}, {})".format(
                    i, tuple(v.shape), length, embed_size
                )
            )
        encoder_hidden_states[i, :length].copy_(v)
        boolean_encoder_mask[i, length:] = False

    # Repeat the conditional branch as well, so both halves of the batch have
    # batch_size * num_samples_per_prompt entries.
    encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_samples_per_prompt, 0)
    boolean_encoder_mask = boolean_encoder_mask.repeat_interleave(num_samples_per_prompt, 0)

    # Unconditional branch: zero embeddings, every position attended.
    negative_prompt_embeds = torch.zeros_like(encoder_hidden_states)
    uncond_attention_mask = torch.ones_like(boolean_encoder_mask)

    # Classifier-free guidance needs an unconditional and a conditional pass;
    # both are stacked into one batch so a single forward pass suffices.
    encoder_hidden_states = torch.cat([negative_prompt_embeds, encoder_hidden_states])
    boolean_encoder_mask = torch.cat([uncond_attention_mask, boolean_encoder_mask])

    return encoder_hidden_states.to(device), boolean_encoder_mask.to(device)
b/outputs/vta-ldm-clip4clip-v-large/1720614438_vta-ldm-clip4clip-v-large_steps_300_guidance_3.0_sampleRate_16000_augment/-yoaSondvkw_000071.wav differ diff --git a/outputs/vta-ldm-clip4clip-v-large/1720614438_vta-ldm-clip4clip-v-large_steps_300_guidance_3.0_sampleRate_16000_augment/0Bp8c3PfAAA_000053.wav b/outputs/vta-ldm-clip4clip-v-large/1720614438_vta-ldm-clip4clip-v-large_steps_300_guidance_3.0_sampleRate_16000_augment/0Bp8c3PfAAA_000053.wav new file mode 100644 index 0000000000000000000000000000000000000000..28dad07d903aa7f02065408db9bc3c4a368db072 Binary files /dev/null and b/outputs/vta-ldm-clip4clip-v-large/1720614438_vta-ldm-clip4clip-v-large_steps_300_guidance_3.0_sampleRate_16000_augment/0Bp8c3PfAAA_000053.wav differ diff --git a/outputs/vta-ldm-clip4clip-v-large/1720614438_vta-ldm-clip4clip-v-large_steps_300_guidance_3.0_sampleRate_16000_augment/0DCit2EBtjs_000030.wav b/outputs/vta-ldm-clip4clip-v-large/1720614438_vta-ldm-clip4clip-v-large_steps_300_guidance_3.0_sampleRate_16000_augment/0DCit2EBtjs_000030.wav new file mode 100644 index 0000000000000000000000000000000000000000..c32e4285f539559b70cc139b1987b2bd781c7a51 Binary files /dev/null and b/outputs/vta-ldm-clip4clip-v-large/1720614438_vta-ldm-clip4clip-v-large_steps_300_guidance_3.0_sampleRate_16000_augment/0DCit2EBtjs_000030.wav differ diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e65908dfee5cc18b886e510db77872ef92e8e443 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,8 @@ +moviepy +datasets +tqdm +torch +numpy +diffusers +transformers +opencv-python \ No newline at end of file diff --git a/tools/__init__.py b/tools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/tools/__pycache__/__init__.cpython-310.pyc b/tools/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 
import os


def default_vae_config():
    """Return the default configuration dict for the first-stage autoencoder.

    The dict nests the settings under ``model.params.first_stage_config`` and
    names ``audioldm.variational_autoencoder.autoencoder.AutoencoderKL`` as
    the target class.
    """
    basic_config = {
        "model": {
            "params": {
                "first_stage_config": {
                    "base_learning_rate": 4.5e-05,
                    "target": "audioldm.variational_autoencoder.autoencoder.AutoencoderKL",
                    "params": {
                        "monitor": "val/rec_loss",
                        "image_key": "fbank",
                        "subband": 1,
                        "embed_dim": 8,
                        "time_shuffle": 1,
                        "ddconfig": {
                            "double_z": True,
                            "z_channels": 8,
                            "resolution": 256,
                            "downsample_time": False,
                            "in_channels": 1,
                            "out_ch": 1,
                            "ch": 128,
                            "ch_mult": [1, 2, 4],
                            "num_res_blocks": 2,
                            "attn_resolutions": [],
                            "dropout": 0.0,
                        },
                    },
                },
            },
        },
    }

    return basic_config


def default_stft_config():
    """Return the default audio / STFT / mel settings per sampling rate.

    The result maps ``"preprocessing_16k"``, ``"preprocessing_24k"``,
    ``"preprocessing_32k"`` and ``"preprocessing_48k"`` to dicts with
    ``"audio"``, ``"stft"`` and ``"mel"`` sections.
    """
    basic_config = {
        "preprocessing_16k": {
            "audio": {"sampling_rate": 16000, "max_wav_value": 32768},
            "stft": {"filter_length": 1024, "hop_length": 160, "win_length": 1024},
            "mel": {
                "n_mel_channels": 64,
                "mel_fmin": 0,
                "mel_fmax": 8000,
                "freqm": 0,
                "timem": 0,
                "blur": False,
                "mean": -4.63,
                "std": 2.74,
                "target_length": 1024,
            },
        },
        "preprocessing_24k": {
            "audio": {"sampling_rate": 24000, "max_wav_value": 32768},
            "stft": {"filter_length": 2048, "hop_length": 240, "win_length": 2048},
            "mel": {
                "n_mel_channels": 64,
                "mel_fmin": 0,
                "mel_fmax": 12000,
                "target_length": 1024,
            },
        },
        "preprocessing_32k": {
            "audio": {"sampling_rate": 32000, "max_wav_value": 32768},
            "stft": {"filter_length": 2048, "hop_length": 320, "win_length": 2048},
            "mel": {
                "n_mel_channels": 64,
                "mel_fmin": 0,
                "mel_fmax": 16000,
                "target_length": 1024,
            },
        },
        "preprocessing_48k": {
            "audio": {"sampling_rate": 48000, "max_wav_value": 32768, "duration": 10.00},
            "stft": {"filter_length": 2048, "hop_length": 480, "win_length": 2048},
            "mel": {
                "n_mel_channels": 64,
                "mel_fmin": 20,
                "mel_fmax": 24000,
            },
        },
    }

    return basic_config


def _resolve_cache_dir():
    """Return the directory in which checkpoint files are expected.

    A module-level ``CACHE_DIR`` is honoured when one exists. This file does
    not define it, so the fallback is the ``AUDIOLDM_CACHE_DIR`` environment
    variable and then ``~/.cache/audioldm``.
    NOTE(review): the fallback mirrors upstream AudioLDM — confirm it is the
    location this project wants.
    """
    cache_dir = globals().get("CACHE_DIR")
    if cache_dir is None:
        cache_dir = os.getenv(
            "AUDIOLDM_CACHE_DIR",
            os.path.join(os.path.expanduser("~"), ".cache", "audioldm"),
        )
    return cache_dir


def get_metadata():
    """Return the local checkpoint path and download URL for each model.

    Returns:
        Dict mapping a model name to ``{"path": ..., "url": ...}``, where
        ``path`` is ``<cache dir>/<model name>.ckpt``.
    """
    cache_dir = _resolve_cache_dir()
    urls = {
        "audioldm-s-full": "https://zenodo.org/record/7600541/files/audioldm-s-full?download=1",
        "audioldm-l-full": "https://zenodo.org/record/7698295/files/audioldm-full-l.ckpt?download=1",
        "audioldm-s-full-v2": "https://zenodo.org/record/7698295/files/audioldm-full-s-v2.ckpt?download=1",
        "audioldm-m-text-ft": "https://zenodo.org/record/7813012/files/audioldm-m-text-ft.ckpt?download=1",
        "audioldm-s-text-ft": "https://zenodo.org/record/7813012/files/audioldm-s-text-ft.ckpt?download=1",
        "audioldm-m-full": "https://zenodo.org/record/7813012/files/audioldm-m-full.ckpt?download=1",
    }
    return {
        name: {"path": os.path.join(cache_dir, name + ".ckpt"), "url": url}
        for name, url in urls.items()
    }
# Extract a mono, 16 kHz, 16-bit PCM WAV track from every video in $video_paths.
video_paths="../data/video_processed/video_gt_augment"
save_dir="../data/video_processed/audio_gt_augment"

# ffmpeg does not create missing output directories.
mkdir -p "$save_dir"

# get wav audio from video
for video_path in "$video_paths"/*; do
    # An unmatched glob stays a literal pattern; skip anything that is not a file.
    [ -f "$video_path" ] || continue
    video_name=$(basename -- "$video_path")
    audio_name="${video_name%.*}.wav"
    audio_path="$save_dir/$audio_name"
    ffmpeg -i "$video_path" -vn -acodec pcm_s16le -ar 16000 -ac 1 "$audio_path"
done
import numpy as np

# FFT length used for each supported sampling rate (Hz).
_N_FFT_BY_FS = {16000: 2048, 24000: 4096, 32000: 2048, 44100: 2048, 48000: 4096}


def a_weight(fs, n_fft, min_db=-80.0):
    """Return the A-weighting curve in dB for the bins of a real FFT.

    Args:
        fs: Sampling rate in Hz.
        n_fft: FFT length; the curve has ``n_fft // 2 + 1`` entries spanning
            0 to ``fs // 2`` Hz.
        min_db: Floor applied to the curve.

    Returns:
        1-D array of weights in dB, each at least ``min_db``.
    """
    freq = np.linspace(0, fs // 2, n_fft // 2 + 1)
    freq_sq = np.power(freq, 2)
    # Replace the DC bin so the logarithms below stay finite.
    freq_sq[0] = 1.0
    weight = 2.0 + 20.0 * (2 * np.log10(12194) + 2 * np.log10(freq_sq)
                           - np.log10(freq_sq + 12194 ** 2)
                           - np.log10(freq_sq + 20.6 ** 2)
                           - 0.5 * np.log10(freq_sq + 107.7 ** 2)
                           - 0.5 * np.log10(freq_sq + 737.9 ** 2))
    weight = np.maximum(weight, min_db)

    return weight


def compute_gain(sound, fs, min_db=-80.0, mode="A_weighting"):
    """Return the per-frame gain of ``sound`` in dB.

    Frames are ``n_fft`` samples long with 50% overlap, where ``n_fft``
    depends on ``fs``. A signal shorter than one frame yields an empty array.

    Args:
        sound: 1-D NumPy array of samples.
        fs: Sampling rate in Hz; one of 16000, 24000, 32000, 44100, 48000.
        min_db: Floor (in dB) applied to every frame gain.
        mode: ``"RMSE"`` for the mean squared amplitude of each frame, or
            ``"A_weighting"`` for the A-weighted power of the Hann-windowed
            frame spectrum.

    Returns:
        1-D array with one gain value in dB per frame.

    Raises:
        ValueError: If ``fs`` or ``mode`` is not supported.
    """
    if fs not in _N_FFT_BY_FS:
        raise ValueError("Invalid fs {}".format(fs))
    if mode not in ("RMSE", "A_weighting"):
        raise ValueError("Invalid mode {}".format(mode))
    n_fft = _N_FFT_BY_FS[fs]
    stride = n_fft // 2
    starts = range(0, len(sound) - n_fft + 1, stride)

    if mode == "RMSE":
        gain = [np.mean(sound[i: i + n_fft] ** 2) for i in starts]
    else:
        # The window and the linear A-weights are the same for every frame,
        # so they are computed once instead of inside the loop.
        window = np.hanning(n_fft + 1)[:-1]
        linear_weight = np.power(10, a_weight(fs, n_fft) / 10)
        gain = [
            np.sum(np.abs(np.fft.rfft(window * sound[i: i + n_fft])) ** 2 * linear_weight)
            for i in starts
        ]

    gain = np.maximum(np.array(gain), np.power(10, min_db / 10))
    gain_db = 10 * np.log10(gain)
    return gain_db


def mix(sound1, sound2, r, fs):
    """Mix two sounds so that ``r`` sets their perceived loudness ratio.

    The peak gain of each sound (from ``compute_gain``) is used to derive the
    blend coefficient ``t``; the blend is then divided by
    ``sqrt(t**2 + (1 - t)**2)``.

    Args:
        sound1: First signal (1-D NumPy array).
        sound2: Second signal, combined elementwise with ``sound1``.
        r: Mixing ratio; ``r == 1`` gives ``t == 1`` (only ``sound1``).
        fs: Sampling rate in Hz, forwarded to ``compute_gain``.

    Returns:
        The mixed signal.
    """
    gain1 = np.max(compute_gain(sound1, fs))  # Decibel
    gain2 = np.max(compute_gain(sound2, fs))
    t = 1.0 / (1 + np.power(10, (gain1 - gain2) / 20.) * (1 - r) / r)
    sound = ((sound1 * t + sound2 * (1 - t)) / np.sqrt(t ** 2 + (1 - t) ** 2))
    return sound
spectral_normalize_torch(magnitudes):\n", + " output = dynamic_range_compression_torch(magnitudes)\n", + " return output\n", + "\n", + "def spectral_de_normalize_torch(magnitudes):\n", + " output = dynamic_range_decompression_torch(magnitudes)\n", + " return output\n", + "\n", + "mel_basis = {}\n", + "hann_window = {}\n", + "\n", + "def mel_spectrogram(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):\n", + " if torch.min(y) < -1.:\n", + " print('min value is ', torch.min(y))\n", + " if torch.max(y) > 1.:\n", + " print('max value is ', torch.max(y))\n", + "\n", + " global mel_basis, hann_window\n", + " if fmax not in mel_basis:\n", + " mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)\n", + " mel_basis[str(fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)\n", + " hann_window[str(y.device)] = torch.hann_window(win_size).to(y.device)\n", + "\n", + " y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft-hop_size)/2), int((n_fft-hop_size)/2)), mode='reflect')\n", + " y = y.squeeze(1)\n", + "\n", + " spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[str(y.device)],\n", + " center=center, pad_mode='reflect', normalized=False, onesided=True)\n", + "\n", + " spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))\n", + "\n", + " spec = torch.matmul(mel_basis[str(fmax)+'_'+str(y.device)], spec)\n", + " spec = spectral_normalize_torch(spec)\n", + "\n", + " return spec\n", + "\n", + "def show_mel(audio_file, save_name=None, sr=16000):\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n", + " ori_wav = np.clip(ori_wav, -1, 1)\n", + " x = torch.FloatTensor(ori_wav)\n", + " x = mel_spectrogram(x.unsqueeze(0), n_fft=2048, num_mels=80, sampling_rate=16000,\n", + " hop_size=128, win_size=128, fmin=0, fmax=8000)\n", + " # hide x axis\n", + " plt.xticks([])\n", + " # hide y axis\n", + " plt.yticks([])\n", + " # more clear\n", + " spec = x.cpu().numpy()[0]\n", + 
" # reverse y\n", + " spec = spec[::-1,:]\n", + " in_mel = spec[:,:]\n", + " # in_mel = spec[:,:624]\n", + " print(in_mel.shape)\n", + " plt.imshow(in_mel)\n", + " # save as pdf\n", + " if save_name is not None:\n", + " plt.savefig(save_name, bbox_inches='tight', pad_inches=0)\n", + "\n", + "def show_melfb(audio_file, save_name=None):\n", + " ori_wav = librosa.load(audio_file, sr=16000)[0]\n", + " ori_wav = np.clip(ori_wav, -1, 1)\n", + " x = torch.FloatTensor(ori_wav)\n", + " x = mel_spectrogram(x.unsqueeze(0), n_fft=2048, num_mels=80, sampling_rate=16000,\n", + " hop_size=256, win_size=128, fmin=0, fmax=8000)\n", + " # hide x axis\n", + " plt.xticks([])\n", + " # hide y axis\n", + " plt.yticks([])\n", + " # more clear\n", + " spec = x.cpu().numpy()[0]\n", + " # in_mel = spec[:,:]\n", + " in_mel = spec[:,:624]\n", + " plt.imshow(in_mel)\n", + " # save as pdf\n", + " if save_name is not None:\n", + " plt.savefig(save_name, bbox_inches='tight', pad_inches=0)\n", + " \n", + "def show_video_frames(video_path, frame_rate=1.0, size=224, save_name=None):\n", + " videos = []\n", + " # for video_path in video_paths:\n", + " # cap = cv2.VideoCapture(video_path)\n", + " cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG)\n", + " frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))\n", + " fps = int(cap.get(cv2.CAP_PROP_FPS))\n", + " if fps < 1:\n", + " images = np.zeros([3, size, size], dtype=np.float32) \n", + " print(\"ERROR: problem reading video file: \", video_path)\n", + " else:\n", + " total_duration = (frameCount + fps - 1) // fps\n", + " start_sec, end_sec = 0, total_duration\n", + " interval = fps / frame_rate\n", + " frames_idx = np.floor(np.arange(start_sec*fps, end_sec*fps, interval))\n", + " ret = True \n", + " \n", + " for i, idx in enumerate(frames_idx):\n", + " cap.set(cv2.CAP_PROP_POS_FRAMES , idx)\n", + " ret, frame = cap.read() \n", + " if not ret: break\n", + " frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) \n", + " videos.append(frame)\n", + " \n", + " 
cap.release()\n", + " \n", + " # concat all images and show in one figure\n", + " # videos = videos[:10]\n", + " images = np.concatenate(videos, axis=1)\n", + " # print(images.shape)\n", + " plt.imshow(images)\n", + " if save_name is not None:\n", + " plt.savefig(save_name, bbox_inches='tight', pad_inches=0, dpi=900)" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3407931/2796252854.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_augment_2/1719298626_tango_video2audio_clip4clip_augment_2_best_steps_300_guidance_3.0_sampleRate_16000_augment/-BAKe6QGTUk_000030_V6-xlXRhkI0_000050_7.wav\"\n", + "show_mel(file, save_name=\"1_wav.pdf\")\n", + "show_video_frames(\"../data/video_processed/video_gt_augment/-BAKe6QGTUk_000030_V6-xlXRhkI0_000050_7.mp4\", save_name=\"1_video.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3407931/2796252854.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAgMAAAAvCAYAAAB+KskzAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuNCwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8WgzjOAAAACXBIWXMAAA9hAAAPYQGoP6dpAACJ7UlEQVR4nOz9Z5SnWXbWif6Oed3fh4/IjLSVleV9VXe1b6lb7dTIohYgCSchMdJg5oo7GGGEuczAYgAJEAKBhJAQamSQ65Ztr+qu6vImK70Pk+Ej/u41x82H918lcWexBtZt7nxQ7rXqQ2VVRL7mvGfv/ezneY4IIQRux+24HbfjdtyO2/EHNuT/0xdwO27H7bgdt+N23I7/Z+N2MXA7bsftuB2343b8AY/bxcDtuB2343bcjtvxBzxuFwO343bcjttxO27HH/C4XQzcjttxO27H7bgdf8DjdjFwO27H7bgdt+N2/AGP28XA7bgdt+N23I7b8Qc89H/L/+S9Z21tjXa7jRDif/Q13Y7bcTtux+24HbfjKxAhBAaDAYcOHULK/3r//99UDKytrXHkyJGv2MXdjttxO27H7bgdt+P/f3Hz5k2Wl5f/q//9v6kYaLfbAHz3+56k1WggA+RVhVISKRTWO9IkRiuB9wHnYGdvj9PLU3zfX5on7EpQJbgR5taY62sFF28aOovTHGyP6Q8r9vYL+oVGasfGvuLIiTvodRps7vQp2nN8/Td8HeNRzmNvvQ8RQEgJSIIUBPSbiIUXAhCEIHHe40LAGot1lvVrV/ilH/qHBOspKoikQEeSwgbajQjvHTLSWGMpxw6hFQ8clnzLR5scbIyQUmBNQb8/ZmM/sDeAvX7J1FyP8SBHIqhcwLhA6etreP6y59CRRdJYs7W1g5s9zAfe9yHShuLxtzyMFxKlBCIIkBIICFH/rBcCgcCJ33tNznv+0ff8CaypKzytNc454kQiZYTWnhBqNGcwqGi1Yh6+B951OmZ/vyI4T3AFg3HJbt+zPwoUY0teSbrTKS4vSJoJxnpirTBeYW3Bp1/MuTI0JO0uD957H9Vgj3R6gbc8+VaaiebRxx9CRwovQQBSKQT1
OxFoAgGBxEkIQhC85h/92T/z5nuz1pIkMUoFhAhopbAWrLPkeWCqF/OeByumGoHBoCRUFbmp8AF2+x5bCUoryUtLUDDdaeJsSTOJcC6QpjE+gFOa8WDErz67w4VBII5innjnu/Hes7V6gw9+4EP0em0aUcqRYwvMLM5NrlEg5Ruo2O+hY7/0o/+S1atXCcERAjgLWgeMDWRZjEIglUQJSV6WaKkYWYUKFi2HvPWuBoKS/mCMtRUgqCpLWQZKIyiNRymoXGB+rokUGhkscaQI3hNpAShkFGPyPp+72ubS2jaFg06ny3SvS5TGvO9r3sfXvOcdTDUauKoiiAhrHPnaKvsr1/m3v/gLEARKg7UVcazxxiJ0io5ACYEQMQjDaGQQQiHiFna8T6RG3H0kQghHf5AjAiA8QThcJcmLAuvAOIeUEu8FszNtdJySD7bpdlrI4FFSEkWKIjRIGGGd42p+mM++fJ2qcljrSSOBM5Y/+73fxfsfuRfdnabMWgTg7Pnr/Ppnnua59U9z4DZYXJzlZ//2L/Fj/+5z5HmfUZmTJZ7SBqSwBC8prSGhwOgO3W6MCnD46AL9/T5zS4u0shZRnCClwHuPlAohJFWZMz44IE4b9PsDBgd7LB89QaQVURTjnWVu7jCNRgdpKqQAJQrGhacwktN39Pj/xlit8exulvzdH/kRKj9k377K9f2LfOzbv5m3xt/Al1+6TD8fE0RBqiOKqkT6HBG38MKiHJTes7w0S5ImKAwuSA4fO0GqNTqKCBK8dUQqwZiCfDykykvirMHBwQG9TgelE7yxtFotWq0maaNDI2mibAkaIiyXr2+xuLDA0cPt/+IevA/YKrByc4cf/al/x7ntLzNQ61g35l//4L/jM59YYTA6YFDmZMqDTnFuQITkoLC0M4UNkqwd04wSlg7NYowjjlOmp2eItCaKIipnQUpUCFjjGA8GICW7ezsID7MLCygkkY5JkoSs0aTRmEIFh3IWpR1SODa3PHfeOUcS/V/fxXBg+aEf/0lGg13O7n2OW6MV3vbeh/nbf+hH+fH/9NsMx0O8G9OIBcPSkMYC5zVBeHw5QmUtZqbbSCHo9VrcvLnKY48+TpzExEmCkApjKqSUhCDIhwO2N9eYnT/Ezs4uMgSm5+YJVUXWbKMQdKfnaTc6SFsgNAhTsr4zpJ11uONk77+4hxDAGseF86u8890PvZnH/2vx31QMvLFhR0qRRBopBNY5lJIopUmEJkniyWJwBC1I45hEB9pjR3lQIpWlHBf0tx2jocAZ0MGiMcw0FVnUYMrEeF+xfWAJPrB/MMQ7g04znAvMzPRoNVt1sgSQkiBAhoggICAJkwTqBXgvccFjjcE5z0GzRaoUVRAkSUDJCIkgiwNSgVIJOhIo7/EyoGNJ1vCIMhCcxHtHUTjK0lOWGucMEBAuoKSjGWc455BKYwkc5IHNUZ+rF67y5BNvgV7g0MIySTMhiSRZlhGkJISAUgqlFCEECAopBULoutgRiiDAB0XlqvodBIGUAa01xkCWKAR1JlZK4b2niAJprGk2NFrUSdYHR2k8PkDwEm8NUgm0FiRaYWJBI5F4KUlTTWkFVaWonKPyloaOEO0uo2HBoV6PdtpAR544TdCRQkiBEBIlNUIGCHWhJmQAFE5Nll2QJJF+c31pKcjSGBUsIlJIIamkIfEJrjKkkSYRDiEkthoTRwppI6zxOO8BgQyOONL44NFSEfA04pgQPHGicEHhpMdHmiTNEMMxzWaDKNLgPceOLdPrdTm+fIhYQhoCba0JSYyQv2/rDhohQEhBrBMUCh1LrIFIQZAGVCCSmlhJvPNESYRzluAjWo0WJt+jmSiEDHgnUEpjjCWWAgsgBUJJlPNoAU4ItAikcQIBtBIoBErXzwoF2sbkViNUgvSealxwazTm7ntO8rGv+zDtRkaeV+g4ZW97j+3NXcrCks4vkyYRKgSkFpQhEEUKgsAEiFUCeJQAITXEgdwI4jQFo4lVTKxA
EBPpCikEUjisCWgtcVoDZrKPBAyKSCuyOIWobiK0ilEEYqXQMsHZIYKUrWG9npxztJpNsCVpM6Fcv0X7wfvZH1m28wpUxCj3bG6usNTtspzNEC95/uO5f0NZHkYkGXGiSBoaJWJaDY2UDcpgmNKOvcKTaUsiAgcH+6RZxqULF4miCG89cRpjjaPbTFEyZmFhkd7MPEma4lxg+cgRpnpT9bfnPPlgQKPZYabbZLrZI4okCE9VSa5e26Hb7hDqqgmCxHvPcN+Tl2uM9td55drv8LFv/haaN1Je23+Ok+JdCDxxo4HSMVnWJlEw3YwZ2RgdFwgjyaucTAuSZsLm6lUefvwtnDlzHm9LdJRQmZJISJrNNt1uj2anw0xnGoQiBMHxkyeQUuMqg5QSJRSNZo+5dkq3PY2YoMyjXNFqJL8vwdTfx2jg2B3s45xnd3uD5eU2wbVRbfiHn/sB7i//OEppsmaDRhrj4pQsnaapM9rO0UsVu4UlExWR9xRlxf7BDsErNrduYU1FFDUYjg7oNFskcczS4jLd6WmiOKMoK2bn5piamibRMcE5gvOoOGOqlTHdymg2xZvXa8wurUaTNJVvJlCCYDS0DA42GO9v8/ql5zh0aoFOo0Hf7bE93MTZuoBCKLI0RcYRU82EymcIOaYRPAPj6KSKNE5YW7/CO9/1JC++cAapPF4IIqkAaCYpaaNJFCdMzSzR6c4QPHR6PXq9aXCubjTHBe1ml5luk5l2j/rHBXHWpyrN/+VdVIVndzAmBP9f5PH/n4qBN0IEQfAei4DJQws+oCKN83VlDwLvPBAQeEabJbs7OTIYhsMh2/uC3b4lLy3FqKQ0FZlOyLRARwJjPAeDEetXbnLf3XfR3x9yZLpLkRcsLMxgjEEIiRASrzwChceBnBQIQk6Khfrf65MXJitYCEzwhMmmRH0LOOuAGIIkGAsBTHBop5CyYjguGZcWZ3KGuWU0DvSHFYVxeCux1iNcAOFRStUJ2YOvKgrvGBSBQjc48H2OJDF4gXMOay1eCKSUCAHBT95jCIRQb6BIQVACD7jgCUG8+WKlVAjpJotYIYUjUP9u7wUqqLoj957RyDMaO4ytMNZS5FCUjqqyuCDxAqzz4AQ4Qf0mRX1dwWOsxTpHHDew3mNEqFGU4MEI1ldusbQ8j1ASKesNUSoBeAj1hh6igDAxUorJRgh4j1SSICVlWZGmEiUEYfKKXKgLPmcsAwOqlJSVpjAVVSkwJlCVNRrjg8T5gBRQGI9z9bwMwFqHiBT1H0Yc5BXBe+IoRtkxURRz+tS9fOA9b2O21UZOTSFWrjJaX2ffR5iFXv3xU69trSOCD7gQQDiC14RgQNYfnhYTiESAExrhHcZZBOr3sXYdRWVw1lO5QGE9ToEBjAk4Ai7U92ECOOvxMRAmz1eKeu3qmGBrlGScV3ggz0ta010GB/s0223sxg7V3CwmSRkNCnIpiWfnOdjcpQwGKcF5i7MBgUYg8dIjvECGgMdhJxgP0lOaQCgrtoc5vYUpRmaPVAXQGTpt0Gw2KSrL9tpVjPP1bhA8eIn1BucslTNEUYwgwk7emxUSJSWD0qLjmBtrWxhjaTYaaB2oKk+aJSy1Orz48lnSE6c5v3GT/WHB9tYmSdpCNftEjQbWHnC0d5KV8SalsyAD1U6JV5o8TQgiJuBQ7ZQximy6Tdpo0sgUadrgyOHj9fonMBqPuXHtGnY8xvkxq8MDdGeG4WhEJ8uo8hwtEppZgzRKmF06BDJhtpuhZJjsNzUKaK1hdXUPHwLOSKJYUxYF+ShHhACRJklanL96gXanw9LcEtVqxWC4T249WlWM5TZeSoZKEnRGnELwiqzdxPmIXreDlstUpeGeu+5FKTDWsrGxwdq1i4iqQJU5W2tryLRBcI7lw0dYuXaTbm+adtIg62XMdDqURtDrRryxbQYgTVPGwxHr65JgBcY5klgx6I8ZjcfoCGSiGY1LluYX2K62ODF9B2Z1zDAfYfGUKlCF
QBSntOOEEQ6bpZS6QdQUdFodpjsNinzAnXfeRxTHmMpgjOXqtYv40ZDh4IDNomC3O8NoNCLSEd5U4AVTrS5JlNDrtYjjjHamaDZEnRQn9xEwrKxtkca6blCVwofA/t4BxpYQKZqtHiIG6SQnFk5hy8Bo1Kd0ASVKRsEikwaFUhg0jUxRKYVJYpppi6zRYHZmiioveOihRwCHCY793W1u3bxJVZWE8YiD0hDSJlcvX+bQ4hJlUaJlTFPFqEbK0eUlChMx140njVa9v2RZzKA/YGVtD+9B2ECUKMbDiuFwOMmE//fx31cMTKB5H0KdUIUAUaMBAo0L9UYZkJNuT7M/sPSHFd5WDHPPwcizP/JUNlAWAe9rdFwpAb5OPpX3XNve4Wjao2gaGkmGqaoanqvMJOFKgpdvdhtCagLuzWpJyLocCUHgqYsDEQQQIfCIN5KRqJOplOBtwIuAkpLgA0EYXIgYFoI8DxjjGJYwGAXyylEZKL2gspNEDiDqLt94g6kCZWWwXuO9x3tQSUxVlkhixkWBjiKUkngv65c2Kd4EDqV0vfHjYFLggJr8RR7vBdI6hFBoLfC+/mGlFODfLCzw0M89eQm2chQmUBSBUW4oLZMO0OBsUhdM3pNIjQuyRihCQEiPNbbuGJ1FC0mSpiBgfnaOOAiSEAimwkqNkwr55vvxgAMvQQaUSvAEwBGkJ1ICUwWSJCFQvw8lJIZACB5LIEjBoIiJsRTGU7r61Rnr8UFTGlcjHx6kDgihJ/fikEGDCAQr8cZhhcYnbRY0VHnJjfOX+XPf/328/eGHaUcxVV4w3l9hdGuHva1tBpXD7/c5dO+d6LiGi40p6839zY/DoSM9geCBAFLEOFcilUD4+p58AJ1l2GIfGRSmCjgnsAZcUAjnCKFGB5y1BEK9dpynshnZ5NO2BIQXaKUwpHgzBg+jsiKEQJqmiMlavHHuHG53lw0bGGQ9ojhle7tkUBrOv3yeO+49BSiU8gTnccIRfFQn78m70wiCqq/FOoeUCQJoJRHj/Vvc9bXfyr1vfQdJc5as1SVOEg72bvHP/qevx9mABwgSZ12N4CHQEoZVOSmIAamJRY14iRBRuIjdg22MDYAh0ZoQAlI6nnr1Fe5dOo4cC3x3CnzJ5sY6G9trrF25xohtHnrgfr760NfyBf4dVkoQgW67Rx5pkkYXKyCLFcI7YhVoaIkSga3tTQ4vL7O3u4kpDTKJKYuChaUlDi3MUxmPL0u607Ps7vTJoggpoZE0WOo06XZbtHsNnnvuAr3sMO2pBtZAPizYPSgxtmJzZ0SkFUomOG8py4rhuE+apeysrrHUupOXnn8F0bT8sz/xT3np5U2cF9jg6rFFpIjiNq20QSUFWRawRhLFNTRujOVgdxsBaBER0Bhv0QqO33U37WaPLEmIk4xgA8PhgPnFRTCOJGpw4tAMrU6TOIn47Odf4sj8w3gfqArL7t6Y8ajC2pKNzYIkySCAdzE7u1vEaUJwsLG6St8OePmll5lamONf/8BP8ndf/AmEUoSgSBoJkRbIbAohoJtGxEpAEHQyQRwnrK9e49DhQ+zt7kKw+FA3ntNT80Tzh2m1WgjnmJ5bor+3x/7ePnNzs2Rpm6lmg9mpFr2pNv39EVcu3OChh+/GWk+RW3b3cqqyIM9L9pyj2WiRJAnOOfYH+/Q6bbbX16ASPP/cCzjtePfXvJPN1QO8B0dAaU0jaVLplLTVQQtoZhJhQSiL8PU3o4Vn72CfYIcIJUAFyqLg0JFjpGmTdqOBrQwhSIy1HD60RH+/Tytrc6Sd0Jrq0GgmfP5zL3Fi8VFCCFSlY39/zO5+hTEVW1tj4iShykua7TbD4QBjSnww/035/b+rGPj9s9IaxgYImMoRNSImn3y9iQRBwDMcOQa5x5QVRQGDsaG0EuehrBzBqzqRygB4BI4oSqj6fayzlNYhpaY0Fetr60zNtHD1TovwUf0zQtTJXdVQF1IQvEDICez+RhkoZD2fweKc
nHTNYdJlS4LwRFJhnas7O8AYQX9Qklee8dhRlJ7+oKA0CmclLtT3LyedvZB1UhYBBgWMy5KgBMLbGgLVmjiOSSLJeHeXRm8KkqhuIkV4cz4NNdTvhUdG9bNEpoBDobHCEbB4VI2EeEUIdlII1F19iDRgGeUwVL4uAsaC0gTKypGbQGEmhZCfJFLqytkFT3AOKSOkEJQuQmDxAqiGNNOIxfkFHr3/bo7NziCUxuzsMNxaZWw8bmGRRqdJCG6yTiRSBoIKOMSkePKEELDeAQEfTD3y8Z4g6sThnAIEtqwIJmYsDLn1VJVFhUBeeionMNYhpMc70EmMlBIdJNaDEBbQ+OCw3mGEpDL1DFtrRf9gxLseeZCk1SLPLfvFmOFgRJAaO3sYZSyD8YjhyOHHlnajHj0IIfGhmiRM9WYxIiQEWT9HEcQElZBIIRjnFS7POagqkm6XwuwBkDuJEyk61lT9A4wpCSHgXCBMijwXPM4YpLaTUqp+z8iYwnokGaXpU1YWpSK88yRJws7WLs+98ApTDzzClZ1tRnnBaHeP1vwc26OcvVfOA4bg6+LZBVAhTFCoGpXywROsnxTUAqEVUkracUSWRZw7c533fex7CVIipEZIx8ziEYRUOFGPvgKmLmxROBewvob3S+vxwSPxRJWnreqCwQTBwcE+VmiScoR3WY1ZSMm1vQOW0n2ki8iHY1ZWb9EaHDDV69FsPIDSmsenH6bKS9auXab0BlzFXqoRaYZOGkjdoN1J0UlCo9XCR5rWbJdy3MMay/TMLHGUkJeGoshRWnMwHKGQ2PEQJSJ6rRmW5qZoNWL644rDx+cnRX3g8OICm3sDtvdG7GwP8SiqsqAqBtj8gLKqIIpI0xRnPFoLnGtz5PAyZVnSTN5FnCZ86OQ38cnrf4mNwZBxVRAnETpSKN2g3W5h45huK0EnLTq9Bo2ZacBwcLDH8tHjTE/NsN8v0FozGg3reTWC8XgfUxVEusn89BIL3R5zsz2Gg4LpuW69hwRotlJurmyxtz/AloLSOkb9PtKPsNZQ+ZqjFAQo72m1WqSNJosLC3TzLkfVcWbaXbpxl9WrF7HeUfmK7ShCpSlR1kFrQdbukGUpUZqi23MkWcZBpHDOMjMzR5rE9IcFzjmcc/U+5RxVfoDaUWjV5IF77ieLBOiEI4dmiCYwXJqmoBVXb6yzuzkmCE2V5xzsbaJkxbgsieIGcZpiipzgHHjH0uIiG1ub3JO8E60lHz31R/jsF55ic+Ua42pMkI4sSXA6odvu4HVKuxWRpR0anZik3aDRbLB1a8RSd5apxVmMrfPjKN8j1i0AinJIWeSk8QzT3QWmWm1OLR9ma2/MwvIcTJqA3kyX66u3GI8rhv0K56EqClw5oKpGFN5PxlopvippZNkkt/7fx38nMjAhUgkQAYL3CKlA1LPLEMQEGRUIKXAhcDDwjAtPPqZOpHmgqCwuQGUcgoAPoFw9n3oj0VtrEa5CR5pGu8XhxUUOz87QLioKY8jTDJW+AbhKhDB1FaYAH2oUI8gaIPECPxkXKOVxNqDw9XULMdmeBF4EpKw73yBABE1lA6NcMRoZykJRlI6xgXFZd7DGeZyQyKBwXqAIBAkEhxMZadwkt47+5iqRbnLq+FHe99hDRM0MNq5xbXULOTtF2soI/o2RgQQMPkhErLGmqjsnZwjCIkVABFACIiXITf1hiDcoeyHgg0eIGk62Vcy4qNgflATrcL7uqEsLlbGoCBAe6+Wb81mHIsKBiLBB4pIuc1NNBhtbXB/1iZKYb/rIX6XVauEdjPpj9sqKAS2KfI/y+hrHH74XW+RvogNSSoSXYAVeyhpZ8h7j3WScowCHcwGFI3iBdbYmUwaL8yX52DIuPaEKGGHwHox1WO8J3oMXaJXgRagTZgAZAs6DEnXRZdwUgTFFUTI11cGYEreywe60R3anGVjBXiF4/dVLHLvzJM4Krm0esFNdRkrN4w8dnxS89dimJsXVRSlqAguIyYjC
1+vMWUflLDYodAiMxgU3dgJf8x3fSndqnmxmkc7UHI12mx/+3q/FmhoD9ELgrCFKIkBR+hJKSSQ1QUmEFshQF6AmaMZ5BULSbMQYU9BIUzb3LVdu3KRtEsruDP2y5Oq1a4TVVQSBztwCo2GBR4APGOeRyiCcx4mAGNfEYCZrapyPmZntURQjmlIwHDvOfuZ3+d6/I/HO4cp6BJO2UoaVYFwIOkkNUhkXiJN6BLK3OySKJEmsiLUkjhroSGCDIUiNsRmV8zQbioYK9NKIJI5I4wSdNdgpCjLdJ+t00cYS4Tly8nhdEArNoZNLjArDI48+gvUGVD2qMM7jrKUoKgiesj+gPxywdWWIee4FQhiTxZJub5q00aYoCtI0odPt0ev1UEA7y5id6ZHoFqdPLSCF4OqNvUlRLShyx8bOHtsHu2xv3KLbnSZOExpZRpQ2iGZbWGsIHnZubZOKCKxnOBixcGQREDgXaEQpIcDp++9nsawwtkTGNWnYGgjeMRyP8IVhNN4h39ti7dJV+nubHFme56kvfJFmq4kIMWhBkmi6nRk6nQ74gqW5DvPT09x95x1kac2i6w/GeB8wJrB7MMIYz8tnz9Hf3Wdufok4UrR6KVJkoEC4wHg4ZG1lg9luB1t5Dqo+y8cPYZ1HipiZdgfvPQ8/9hjWVzhnUVpjvMO6gK0MpbUUgz55/4AX129SFCUCg8AyNztLnHWwZUXSSGk0mvR6PdJIoKSi3czQJNx76hBCwO5+gXMeEaCsPDdvbLLdH3DmwkV0kMzOzpEkMVOLsygp6NgSISMG/UENs+uIoiyZXpgh7jRwzhOriDhL6MzM8MDDD+KcRURgncPaGgEY5TnSBvJ+n/F+wcbVqwyGQ7LYc/ncdabmppEirdH0KDDdm6fb7SFCyezULIfmplmcW2ZpoQ0C9vZznA9Y5xmMDKPKcPPSZQbb+0zPzhApTZKlpK0ubdGFAMPRiO31bXrNDFxgMBp95YsBKeukb0xZM6vFG7OwelEKKWu2+qQgcA4GuWUwcpgyMK4847ImsLkgqKpAFEmsrWH9em4u2RyWNKKU3dVVRBTxwa96F4vTHWzh2F/fYrizx0AfsHDPHZiqnCADqv4dHpASIVVdEEiJDwIfAgSBFBEuVKjJxiCFQvhAcB4Z3riXyfwZjzOe4dgxGFucdZSVxVjxJixdmrqD8r4eb0RB1GxS6xFRhyQdYMcl1y5c4wf+xl/mfe9+O5QwHo6pQhMrhmzf3OL4XXdizbi+B2rGskDWSL+qf7+SllC3nUjp8MHjnYPgsM6gvCcoani88igV13Nm6xkUhrys78cHQ1V4jAkYb/EmECf1iMC6gHUghccJUS/aALm1WOuJ45i9nT5vf9fbaI3GDJ3CJm02+mM2t8bs7R3Q6GQMdncpzl7j6LFDRNLXxNIQavREuJpDGDwSiaSeFwshkcESiHDB4fHYoBA6woQSU9bJKTioQl0rWyewru4sna+RASdq4qhQmjxYMqExFlCy5hi4+n1nWYtYaUZVn6vnr1AsGq7sX2V80Aeh2djpc2Av4a3DB8mFK68zPT3Fg/ccQmmFoEZU3tDuhje4MkEglKqRCg/SB4IISASxVKRpyqGpLq14j0fe9200WykyTkAooETQwGNhgrWFoNBxSo2daayrORwAsoRO2+GdIkezubVFljax5ZAkisgSTdxps5PnjNfXcMOCrY0tukXOxf0B04cPc/n8BWZ9qIupSYHpQ/1MszStuRvOolRNEpY0cVTYqiSkiuHYMC7HNWnXOqQU6CgBZ1gfScYji+zGJFJgg2A8lhgJ7fgNPkJAOIGXjtx6hIqxJmAcvPfBUzQbGdLXmELlHEMnaaYZpTAsNhtgHXNZA9HTyJMd+jsDRlsFV25cQz8bU7mSKNIkcYxIFInSpGkDpEbr+h8bAsbUao7+3jZpmlJVJXme451kb++A/d0hK2IVKQLf9OH38dj9p/jUp59D3L9MkRuuXVvhyo0riCCR
QmKtJS9zKlsSlK15JsGilKwVTtZiSoNQEVMLh2m3GqSJ5tWN5/AeCIrNyzv89mc+z8FoPEnmMTqJibUmyRpoVfOCsjQBqXE+QJAMhwcMBwccajSxpmA8HjMaDdk+6LO1cYBWdQH2ke96krlDx+o9zwf2dge8duYSl66C0Bl2XOCkJx+P0VoTsDUiJh0Sha0M3nj6B33mF5dYmJ2n2Y5JY82rm89xcDBge3WICoHPfv5pKlMRJ4o0TRBKMdvMEFLQyJp4L4jTGO8llSmQQjEej1ACSlNRFWOGcsRwMGBv94DVlTUUgbc8eD9PPHgnv/XbzyLEPVST5L+5t4UWIKOIKi+phMXkYzrTswjhqaqcVquFcxZv6z20ynNmFo/SbqU0GxE3x1fY2dsBJGuXdnnq2edY39jCBdCRIElT4iQhAFkaA5osy4i0xniP8w5TOXZ3NkiyBsIbRuMxg/GQwc6A/Z1rKCXRAr7n2+/n5B13sr62T6BFv59z4fINbt68RlAaZz02GIrxGGtznLckqUZMFFjGWpxxmDKn25tidm6BVqeBNf2vfDEA9cw3TGb7BD9hiddwb7BuQjaT9WJG0B85BmOHc548N1TGY6zHBYnxAulr2RR4lIQQNFZoXDnmyoVL/JE/+jEWZGDfSDa3d7l4dY3ZuRn2D/YYXLnF8qFuTboRHmTNBRAqQinqpBkEgpqxHwAhdN2sRYLgqNn/eKzTNUrhPdbWyEWwNd9gMDIMCgeVx3hbE71sDQXayQTCImrylQJhNT5IhlETH6DVbGLyMW958D4KGTOwFRevrLO9tU13eoad3SHXtl/g0FyHkyeWCNg3iYJKgK/Be4IyKBXhvJkQJBXBywlBSSCkxKPxLhCCBVGrLFwIjAtPZR3G1HCs8fXsu+7WPKnSBKVQKCrjiHRACAXWU9XcK4xxNFKQUczazRWuv3aZ1SjhwvqQra1N0naHm9dvMH94Ee88w/M3eGxQ8dD9y3j/BqkiTORNAYmEUKMcHoEUpp49hoDAQaiLLYujkJqdUcn8XJt29xCNqVmmFxZ4+lc/XhPUXN2lO++QQuKcRwaPqCQGkLFCBIUVMaXLGBWGOJYYa4i05svPvUy2tM92lnFrZYOsO81wNGS/Klm5voJQimH/gCyN+EMffAzn62RsTEVV2Xp9hUBl6qLZFzVnxIdAZQz4QFEWNJs9xqMROji29ip+/J//U/7nv/oD2NIQQgkETHuend0h7TigpcAJj9QxpS0pDwYkcUQca+I4IYolgRFeCHKrmZ9qMdVqMdOOieIEISTLJ0+xOx5xLOkRaU3DBdKq5M4s4+L6LXSzSZYpqhKkCmg1kVFO2MdSqpqPEmoiYJLEWOnwFoTwfPQ7voeFkydRUuK1JEpisAbvDPc/+ATnXniawdgiWxrrJRt7hv0bO7SbKUksaWeBZuZoZJ52pkmTOqHZEEhEzZ0YlBUjE7i1tUdQgoeXjnI0y8j72+jhiMgEKi1oJQmh6yn2RvQrwwtnXqfTnSKKU7JWkzRNyMcFSdogSFWTdUU91lOqVnzESZNRPkIIQZylyCgmbXWRWpFGKVmacNepk1y5tsLYS148u05ZQUnMwV6f0cFeTbitCvb39inNmBvXb9BII1CaqU6XSCmOHjlEr5kRnGeq1+L4qVkEivT5tEaaQgQq4rc+93niJKWTdiAIep0pdBwhpCbSui5IVS2ljrTCWlcXOwKKvETqmLStiRpNGp2KNElJ0gZZFCGU4tLVG+z2DVo3GOUVJYK9zT6j8QahqhgXQ/a2t9A6JtESpRTtboc0Sem2GhxZWqR7dJGN3Zx7HziCrDnURGlE06XsRwaL4Dc+8/kapVARjUaTLGsRlCSKJFoneOfrBkiA9VCVOVGkyccjnHFEcYN2FNFotvEC4iglTVIefOheXvz3LyLmmrx4bhVbSfaHhr1ByXh0QFXm+MrSH+6xt7tLo3GTNE4QQjA11SOOIg4vzDI7NcVMb4l8bLnnvmUE
0L6cMShquaiMFC+eeR2pNd3uFFEc02610UnKwUGfNOvUDfGECCuERCqF1oZmu0N/2KedtciaCpWlpHGHrFHLDLM4ZXZqmvOfuMDKlGCjXzEuHFZH3NrbYTQaYoucoizY39shTSKur6wQSVXLP9OUJI45eXSZxek2NzeG3HnnIZqtiP4g+coXA4H6I63npTXRI4SAnNy4DY7ga6IdQhC8ZTC25JUjBDfRGnusBeMdJngSFNZ6oIY+ja9nlFmjgTcFG5cucuXVCzw3qLh06RoqaZDcWOfWrS2ss3zXn/w6GlmMFw7hBQjHBKxHBV0jAjJMtPd1sndBob2bgOo1pCulwLuAczX8r3WEE5bCQV4FSuPw1uFCzWD3wWJcrSutvCUCSicAhxARhQ0M+mOqqqLV1FQhsHXhGherS9zaGNMfjdnuHzAev05RFIxGA44fWWL58AdRStfws4gonEH6mmvgncA7V5PKgkcRUFpSX7KtSZJiQswLAZ3ElKVgZB3N4ChNnSRtUNiqpHKWgMSaWiZqbIUKb8zrA4nQ4DzGR5SmIE3rTSpSkI8Nzz7zIn7xEHulob+2xuXRZXSvw6uvnWOc5xTjEUcPzVNV88Ck85dRvQEzkU4EgZee4CUhKLw3tQacGmGf1JRICX0j+b6//hPEjYwoTkEEvviJX8BZQ+UEInhCqO9HCEVR1pLPIjikUbQwGAvXNnbY29sjTWKEb9JIIm7u7tGymmGjjb95g+LmGtcrS6UUzlqsK1FCsL2/U1NQavSf0aggH1cYU5P96u6eSZEzSTRCoLTElwKDgyonFpa9seXXP/FJvvd//Sv1eE1KdATv+bqP8S/+/t/DpIJuQ2Od5uLKiHammGpEeAE6VlgfKIemRoxChA+SR+46QTOJECJQWDBOMDU1S8dWSOFR/T5tJTm7ssrdd57mSJIyNTPPYOMCEZNvOQS8c/gQsLZeV7WJwkT5IUCGMY0oIo6bfPSP/XG8dFT5mDIfE8spQjAEH/i6b/rDXHn5GQYmUFWB0tZjQKSg8rX8aa8/mhTqBWkSEUdjsljhxJDBqMB5xYYx5Lljb2+PmV4bEDiV4qs+xg8xJqDSJrGKWdlY5eCgBGo5dHAeMaNQOkJIRWkMk7SDiKOJxwfISXHc6kxN/AP2caFuDEpjoSwpRMFwoPiPv/Y59vpDlhYPs3fmDFnaIopTglToVhthDcqkzGYNrKlqVEJNig4f2NrZ4kPvfoJmmvDcmRvcuHaNmYUWVWFrJUfl2d/dZTAcErwn0QUyaRBJSV4aUglRpGoiZqyRQqIQ2BCQMhBnGTrJkKpPVRRU1uKspxjkGGPZ2d5hZnaan/nlz6KTjCRtEscxjUa7ljrHDVIENk5opDFR0qjvQdYjVYC8LFluNDgy0+UgL1i7cZPB4N5a1lxYIh1x4eoa1UBQiBIh6sLY6xQdVaALtI+JdEYIIOOIgEIFj9YJZV6QNjvoKObmjWu0Wh2McVTWErylGBcMVcyvf/45tv7sM3TvPczmH7mD+MF5GkkDLyRR1iaoGJ8YulmKjFJ6U70313nNO7Dce3yZVqMBWZPf/K0vcPreY1SVpTKOOMrYWNlmeNAnBFn7hugY6xOiuEEqa14T1KoprRRBaIKUKG8RWtNpa6Su0bKqtDgLlSkIY8fOXp+Zbof/8J8/y9r3Ps/i15+m9eEjNJZ6Nf9NJkQZiEgjsoyo0aoRDyUnEnMoxjnXLp7lI+96jDTSPPvqZXZ2drCuzc7mwf+AYuD36RUDHiEFUoi6M7P15i4CWB+wPjCQCT3rqMjIOosc9C9jXD3/ME4QfE3w8gTChChXuIjhOKfVaFCYnKvrm3z2Nz7P8NAhwvnznCs8pY4Y5zmD4QH+j39tTcSRqoaaZai7aCLATkBoJmx+hw8GJeoZtZYS74BJYvLeAZIQJEQx/VHOZtFg6dAcWWuKuUPLnH3qU1QH69gQao5AVMM33jusDAQl0SLggmJjaxdrPbYySCH47G98
msbJO1i7fIN8d8C6juiPx/iypCpzFnop1rpJcq/HDdIrnLQEaoKRQONFTTaUQuLerGdUPTLwDuEt1tQysa1xid2Lefejj3BkZpmZw0dJm4qf+0d/E+uokY3JWCR4ydgGlHcoLZG63khLJzjoD4jjmCxWqEaEk5qrm5tkFjaGBr16na5MWM0LjPE4V2LznKX5mZr/8UaXKTxWeVRQgMKKChVq8mkNjQbExMDHewdCoqSgHSlubW7RaE/jncNWDiEM977rQ3zhV38FTY0uWOuorGNc5aSRQmqF1lEtGZIG4zSD0ZilqQZHFufI4hgpNa+trfFo1qCnZhllKdsrG3SzFrOLy1jv8MExHvbZs4bN9W2mZjtEkSaOBXEiana9rJEbKQRBOJRUE05BDcHOJm280FhrQQu+86/9PQ4fOVkXO0ohlQaf8/b3fBX/4Z/8I4p8jIkjikqwsZezsulIJ+qTWA9oZCm9bkYjhqyRcFCVjPKKyjhu7A65ev0mvakub+3NcPL4KYYrK2g7QnrBiSOHkN4zn7YRwxEHB2NirYCAKw1a/54HhHPmTS4Qvta0eBFqVnbl6+djHfgSayuKvP9m8XfH6dNYD1JpjK2/fe9CjTyEusHwb0qfBNaDKQzDsk5eBI+MYrzXjPMhPgSiJEZHCemp0zgF47V1cmNo2pRyZQ09LKhGQ0KQDEzF+bNXWTh2jKMLM8RZE4BRf4DUsuaUqJgQIASLlBNzHu8xtgJZF6U1D0mjtSKLG3itsUFQOgfWUlQVSZJgqgqlFFWRU1QF3ljyoiDLGiRao+OoVuI0ujz3+nWMCxhTNyq/+POfopu2SPJ9yn7JeGPEcDjCG8PqaMB7H3knF579HMODPlkjrf0xZE3kFFEELsLLCpwnSTOCCBjja56U86hcItOIOI4pK0MUZfX7pWawW2upqhwhI0pToYOnMhVFVeGcYzga1bB4FBEphUeweTDgT33fn2dve593fvWHaHbnqMYj2kkTubPLVBDsEOiPxxhTcvHCFR78qqOUqzeJkgZKxUSxwot6FBolGXiJ9wVFVZKXBXjJaJyj44TgAsGFWqcfK7SO0UEikxb912/R/7vrLP/E+zFVjjMgRS1pt76WcpfWEvoDkiQh1QqQKAkvXbzJM5//LKPBiEff/X5+8T9/imaUYMstqv4OxWbBaNDHBUWV56i4xebaBlvZLWZm5xBKMujv16MUGcNkVYfgf+85eweYGrmlRpUjrVAuoNIUOyGt3/jl15GfeY3j//yDRFpT5DlxXEuvizzHeY8tK7JGo/b+iWPQis7cIk+/cpVxWdGY6vHMS6/jcous/gcQCEMIGGPxwb1pkBPe7K99vUkQML4mCq3vO/70X/4HdGbmaLab/MM//mFKYwlS4kytIR+V9Qwq8hIZG8ZWsbd/gPKGTiMjarW51d9jpjvDxnDE8aSDmF6gX4w4IDDsD+j0OggPXngQHhlqNXfwsnbkcxbjHNYYTOUnZj5uYlajESLG+zec/+oHJ1VEKgT9YcUf/ms/XEsAteDoXQ/w43/nf8VZXUs2dL3BvaFHJgjKskBKyfb2GsoJWjIim27w7LXL3OXAHuSMb91Ct6a5Z2kJGxzj3S12Ll5ld29Er5shZVIneuNwKkyklB6ogAgpLc4HIhHV4wEPEg+hxgiQCiU100lC5gq+4S/+70QqQigJ9oCfFf8fKl/VsLyvsBP7PCkEQQaCDORVTqw0hZOYIqeTCGYaGY1mk97yArv9A07PLTDjPbe8Y3GqRae7gKkqxqZg0N/muS98kYWlj5Cldbp2GoRT2BAY5zUzOIh6ZqdVTejyoe7QBsMxzd58fc9esLZVsr21RzNT6ChGR4H3f+t389M/9QvMNRVpXCecs6sDug1NK5N0gyIKDl9O+Ckm4sTiDLZyzPQ6DArD5v4BrdYUyljc1hraQ9ZMiFH0hCDPc0Z2jKwq2r0ZPvW5Vzh16hCPPHwKISFJNEp5pIgQXuCRWGcoq6Lu
8oKru5BJSlWRJo5bvOv9X4N0Du8r8sKRJooQPCrIiaFUTFk5jPfEUmCEwE14HcPCszfM2R3VMLoQByAUhbFkaYu9sWFQFIhRVI8LpheJy4JidR1bGLKsw0F/QHMqparGxErVChJb81AcNfFyYjQBokbsrKuJeM4LdJLUJLgArjIEW6sfBv2CdichBI+PIqwL1FqUMEn7DoGv0aHAxJeECem1fkZBCCrrUAS09rTjiC1rAE8ra9DrzeGChZlpzM0VQrA0EsE3vO8eysrxQz/9ArrRwBjH/vAqrWaLXqvJ3JHjKK3IxzmD7XVOnLgDnXYoyhzvLUVeMs5zqtE+ReWQOqIyjq3NVV7+8pd56zveRmdqnkjVrpq729s1s0dKVq9dQArFkZOnsNbjbIUzjmE+rln3cUzqGrVBmxTc2OqjZUBRq5Ke/8QnOS4i/vjfei9j4xiVJT/2S5fqTnRvl+Wjy1x7KWH56CGa7R5aSHb3tnHjPsdO3EcIno3NFZrNDlVVsbOzzWg4oNVKqYxgdDDk5tp1qiLn8JEltkQtx24lLawPaK3wKKSsvS20ClhT4XzAWYtzDmMqdBSRpglpFJOXjisXz2KdZG93i739ba49/TRzw4Jv+svvpHr8BP/5qcusbNUE39HY0G61aIoe00tLRHFCVVVsrt9kbm6OuYVl8qJujnwIFLlBkSN9oMpLrLHs7+/x2gvPcvf999HpzTDwEjlrYVUST0fsbu8BbyC89TjUm4rSOKw1+LLEphk2jYmiGC0kK9u7/Pwv/iesFUwfvZMbr59lfmD41u99knRqltw6/uXHz6LSmHGeI7QibTRZWphhYfkoUimGgwGmv8WRk6cRSrK/v0cUaYqipN/vM96+hUwzhIwYFzk3b15j5fJlHnj0QbwxaBdBKojGCt2O2d3erkcm1hLFCmsM3jqMdRRFXnMdQqDTaZMmKa3eFDe292ppcZKAcWy88Ar+5spXvhgwxpJEtaSqds1yIMF4j3a1/rg28InIYk0oDzh05714a/DBES/ewdUzZ5lOA8Y7DvKSTDuyNEYqRWEMYxPx0NFDLM33GJYVcaNBwzoG167SjGK8tejBPq6/y9T8Ar/+W1/mAx94O512+qY80IWavKg0QH2dla2oJhrsWhGR4KjqMiZAsDWj1XtPWVriTpM41gxGfQQpVTlGloJTj7yVtaGgm2ggUJZgDyqyJNBMNCJWCOEoC8GDp09iQg25r+zuQKvNKO+Tes0wiWiZknjQx5oRQSpOPPwEv/brX+R973qUhcW5ev4kaydFa8DLNzo2T/Aa5x061AS6gMO5yQc7mX+G0hBJwdrqDkrGlPkA7y2JEqwMJaawLLZibIjpjwRKFCRS0GgmE3KopHKBykve/uDpiWWzYHtUL8pjM7Ps31rD5Z75TpdqOEbKA9x4zEHZJ5o/xL5RfP6pl3nfex6pOQghELCgJMHX45kQAipStUsbslathEDazAhRginGgOSP/k/fy7C/R6vZI84ygrW02h06ccQwdwig9BHX1w6IdEwUV0g3QMeaTjsljQJxnNIf52zt7rNXSa5vbXHQ7/PVb1/CZw3s/gFVBdPtFgMLw51Vxq5iO7ccffxJ+lXB5etXGA33WD40Qz4uGAzGdXcWAsHWHYF8wx47WJyvUQKExLjamOpgXNWcDWsASz4eI0VSKxOoxx1COsoQJqodMUFMonrU5eqC3BhT816kxDhLcAanK4KtHf+8d+goRUUxTHWwa+vkvqIYDVBKMhxssekcyhmco7YgFh6MIlKACHWBGMArOSFDRihZj3wa7TZ7O3uoWGJHBVmWonSY7BcVAkklNJoCbJgURSBDvV69r4lPSEmsaz+OovLEkSJRE+Mw72lFjqIoUErTarWI4xhvAj4IenefZvuVs4xGY1565VWkF2TNZm3GRIGOYqbnFki6XRB1Nzk91WR2ZrZ+LkJghmX9rBJFL+1RxYKO86g4pepbWk7jTg1pN2Bufg4BpEmM
D4Lt1Zssn7iTf/9Pf5DSSv7y3/4HhHLISCb4GJqp4tmnn+HE6dOsD69zzwMPUeU5rWYLCQRr0VJz16lTxMHxwkuvYaxFRTFZllE5i5SS0dhy+MSd6CSrvUWSiEPLR5GiNmxzTtFstvChRuJm52ZppJq52SlGBRS9nEZTs72yRdcEZlvTeGVJ0trHZWdzlVN3P4wwJTZ4dvd36fe3afRmiaVmd3OXzz/1u3zgwx/mhdde4p4776LXmeUb3v0IXz67wsOPPIqKYxbn55EznldfPVNLeYt6LzWVBaWIkyZKtBEqQipFt9ej250mihVSxBTFHlLUCGmcabRosawVOm7gnKPX6+DKMVmWsjA/TZw04a/Os/N3zpP1mmQXS6SKCPMZHBI0Ek1e5IxGJaZSrFy5XDv+aUmzN0OWxhyZanNkKuXKep+FhTkSa0mHI85fuYaIKqTUZK0YJ2MoC7RuMDXXo9NtgBBorVlYXEQuLKDiuDbjm/i9KKWYm5ujmdRIopcNSlPRzGKakabbgKXFRYQIuL/dZfsHv8iwm3JoJUU0ISzGFCJnMNqh0ZjCe0+M5bc/8Wu8+71fxdkbVzl16m7ajYxmt8d4PMaUBYlKaGcJ4+OLX/liwDlLURS185r3NXOVGgK1tq7Eaol8RKwUPllgf3dEp5sSR4pv+M4/z9/7i99H4TzDSnD29W1ajZQkrmg2UtKkQNBgWIx57UbJ5dV1vvZrT2B605Rrt7C+JgoelDCcmebu++7kxvWrPPfcS7zlyYdJ47guVKSZeIlPRp3OTeBjQ1mWv28e6icbpq+92CcGSLVZkSPSms78Uc6evcjp08uoOMZ7Q+EEqgQtFVuDgt1+STuN0FoRR4JWIyWOY/ZHOYPCsT4qWLm1xYfvuAcdNaj2tpluNhlbx8HwgFGaIeZnaMy2KPb6XLl+i+n5WeSkC/MKpJcTuNkzGuaULiBFYDAqwDsiFdVJSNaGO9Z7dGxRwiEaPdZu3mR+rkOSJIgQWDpxD6tnXqYMNWHwzNkNEIF2IyWNC9IkppnVhVrlFQfDAaVvcvbmOvt7Qz565BhFa4rBYAVXjIiiBoX3VP0tbhmPml9g+Y4jXLlymXAl59rxRZaXD6EEOAkET5xIokijRI0mOWcmigOFCA5ncrwZ0UgyorTDH/tTfwbvS/oHu7VNrLMEa9BK4xAYGzDO1bPUUEu8SgNYj8NRmRJCXnM+vAVbMS7KOlknMdN338/ul59iZIYUZYFIU3aKnEGSIeYWeOG1Fzl2dJHZ2TadTsbZC5dqT4nJGgouUJUFQUhi+YYTJrUJkXMga3hTCU3QmrJ0uLygyPPJuReQZrW+vhARIpS4iXcUwiEB50ucFbhJNx1CwLt6Dh9FmoACb5hKE1YDxFFCpCOcr0mpU3ecYvzKKwhqhch6XhAfOkpYuziB72WthpDU8s+g8Kr2oySOEFqjVESkavVGf9hn2B+QdTJA0mxOMc53GY0K4tqjmPsef5RXn/kiBFDWYr2pSYl4rPWT7zRgZM03UlpM3A3q9yhEQHjLA3cc5dyNWzTaLXSSgPe16UuzQ9RpsJdbfvinXiEWgukjp7n7eI+N0aD2D9GCONGkkUZJSRrX3IsamQ30up3aTC3URXCpFX67ZPDUBqql6cYNnph/jHDL0D7RwaYOrQRV5bnj7vu46567+dAjd/DS9Q3e/8jjfP6153F5CUKQxA3e/8EPESUR+7v7dJptdLPNazcso8qT+RH33znH9CMPgPf87Kc+wfqNa0xNz9LsneDYsQ7r1TpFnqMiCd6SxVEN10dx7VRpLVEc0+l0cd5Qla72M4lTGs0GpRnQTJqIKy2WwiJhyhJvCOyGwc1XRKfa3H3fw9x9710ok9NrRrx4doWpXofRcIj3cMfpOzlx6iTNZpMXX7jBM5/ZYn/zIr3Wcaq8wzNPXeIdH5xm6v57QAr+2U//JFQ5c8fu5djRBby0WGPAQxxHRDIQy9riXgiJkBJTlUxPT9cd
vHcTZa4gzZqAoJqMY976jndTVSVplkDu2f30GvPffze9uWm6yy2mp5YIZ9dYfWWb0SOBNG1grCNJFI+/5S2ISJEPBoyqioSM557dZnHmMVqtXXa2LYsnTmGo+I0vP8ulM6+SJTG9Q3dz+uQUuztbpGlEHNej6SzSdfMxMcUK3iKlotPpEKjVBG+oVtIkYjAco5WiubzMocNHUSqQJA2q8Zj+6+tMf+eDHL9/nmMnD5OSkD+zwjiK2T6cMhrn+OCZmZ7ij37Hd9T3H0fMzMww1Wrz7FXDcKC5c2pAa75J7+H7iUf2K18MVOYN4xyJ9IpEgVayhtgnBLzgPT6UaB3x5/+3H6LVTimKkhBplk+eRona/KayAesCByOLPSiJM4EPhuAOMMaSq4TS1QZEjTtOMdrbwwyG5DrBzU4TUzI/nXF46VGSrMUrL53j9KljtLodhKyJdSHUid16jzP1HN17j47AGEhTUR8q8oZ9YAjgA0GBMweIOOHv/PN/Q5QIRqM+TSEJvjZTGaxfnThp1ZvxuAKblyil2BsFAjmmsiQTT3Mh6nMcOnfey/74OQb7fYzUlN1Zsl6buUNTHFmY4o47TpLECZevrJBqWDy6hBTgcDXDWEp0FCGEnSAfIGWMFBNiF79Pu1/uI6OYH/6pnyRt1jKdVquN947v+K7v4h/9pT9HVULlahvOUWkIuWNvWMO4SeKwbkCgtoFtuMBwWNaOxRKSoyfQLqcaj3CuIDew12gzdXiejZ0tYlVxaGmae+59gKAbnDt7lbvvOo4IogaKvaEc2Ylz5URFICcW1xPXBDEhddqiwFa1ztp7z+BgRJLWyUa32xR7+7WhTVDoiYTQOle75YlAWZUTeaPEB0MUayIsriqRUhInKQ5JfPwI7pUzeAHjUcF+nDC/eJgL1y7xyKP30m012O+PuO/Rx4njhAuf+U2KvABqkqoE1IREONHI1I6CQiKUrklsaUKaZtxaWaPdyRiVlvmZFghPURQAfMuf/JP8zI/+8/pwH2F/j6/jwoRX4vFIpLMQJJGqrcIlNYLQ0YEskjRbLdJGp/4WpEb1GojpNvagIsweoqlSbq3epGsVKEmkFHoCgcexRkqNVIHgDZ2pBvlghLUF+Thnc3/AqLDs7e7T6TQJOrB2a5uZmXoWK6VEa8mf+ys/wJ/+5m8kjjRmZGtSlPA4F3DOY3zAy/qb1DJMRn68eWaHFIFYCbSSRMHRbHVRSYa1BikUlXfMPHgP55+9zKaZITjLjQt9jh5qICRoXSfQMu+TdVN2V7donVwiCpBiEUGgO1N14xAc41GB3CvZe2WN5Y8eR0WSZz7/mzz+nnejSdj77CYLb1simW9gjCcEz9OfeZZr23dwaOkO/trf+mnufew4Dz9+BILHWkMSJ0gl6LXbjIdD0maHV6/tkXan6WrFtdUhCMNcN+HVtYSqWmZ9Q9Hcd8zO7RJLxcatNVI35rd+4xN868f+BGXf0VyarRuz3Rs0lk+ishRBSlnV8+Vrl86w8OgjpM0Gl37mFY583d2oruT8669y+t77aDQPM37tgPXn9pn++oxYCQIxuat5KD0xxdzMDK6sJu8z4rOfeY2Xv7CCVg2eePwe3vH24/zOb77IaFfz4//0F/nIt76Pu++YYzWfJ1SGK6/tMLXYQdiSOIrY37tF2oHIKbZvbHPszqMoqYgwyPGA5uIxhKyNvfK8oBgOGezuMLt8BBoZxhhGu1vMT8+jlOLav32Z+/7C4zgFe9vr7F2/ytLiHNzX49iJRb78I5/l8LefZrrbwVhDnKRIwHdaXDy/yr/6od+gMmOOLU8zPXuCT/7Hz3Do+DHe9tX3cX1bsWkO4SpBdi2wvLRNJBX94QHNTLCxu8rifQ+wu7LK4olFgoPMD9DtKUTahOAxxlIUOfu7a/QOHyJqR/UYdH+PdreLkIpIR1z5yVXu/NP34eO6qfvib/4UX/+n/yIzH7uH9S+scvP8NY6/82jtVCg8saqt3ZMkwZcW
Yww3djzBpky32rzymdcZbd5iNHRf+WIgy2KiSBIc1PayUJlqYkgSsN7XDznUJXeWtcjznPFojGi3kAQq52ufdVcb/1hbYQ0Q5VTG09ARSSRoxhHXgyeKE4wxLDz0ICuvvUZ6+CRJAu9899vJmi32NtcZbKzTiGOuX19n6YhgqldXkUGBt+L3yGjBEyd1sokjkEEQTIUPvjbH4I3568RyVsXoSGOqnLKyECriBD7w0a/nF//ND2G8R4SAFvX9165YEmMcLoRa0+A9s1nCmnO1oZFUzD78IIOnv0R66AT9vGBxrsv7v/bDNNOEG2fP0V9ZITcBNTXD+s1NFpdnEbJEyrobShKJVW+4ENQn+1UGmEBrakJeEdJSqQ5Zq4U1Y8qiQooCpWH5+ElcAFOVSARaK9RkTFIjJnWVYJ1D6QitAqEqqKoa3lY6xTlHd3mZ4eYGwQrSI8uMt3eYjQXf+Sc/Rl6VtDtt+tt7XL16Be0d58457rjrJFJKrKnq4sv7icwxIJVGIDDe1AgHEhFLlNCUZcV4PCLLmhRFTpzUZj8f+ZY/xsf/1T+ncIHgHDGBMji8nXgPCEUoPdZ7klj9HoRvC47MTnF1a4+gasMV32rTO3aE7bUtOneeZvvWFufOnuOd736IE6dPYcqKd73jSVZubfPSyy9P5KgKIWoVhhW1u5hW9UmCUmriqE6oQtQmRPsHG6ysVrR7HZQzpFqTpDGDfp8krW1+v/YbPsZP/MsfwUhN6t+wzA4YN1mb1uOCRSOIdG3UBCBkfX9KOB65+xSbhSdNs1pJQ8AEy9K9p/ny02eY7U2xt75OZUoa3Q7C1f4hWiikNmANWewoyhJrKu796u/gpU9+nNFgm35eMs7r00B/9Id/mH/yL/4x3juefeZF3v2eh4mTulsaD/tcvHCJhSNH8Pk+u/lw4qHh30Ts8tLgdYx0Fq0FSkoaiAmxEiT1SE54y0KnSZykCCkpq5IsS2ojHuk5cuIQ5/IxblxXYlNZgqkcp48dJVRjRiPHmXMXsEHz0se/gFzdwE5rTt1/kjB9H0pHXD7zKjpO2f+Zs2w9VHBXcjdKzXH5lVe4b1oyaB1lf8ly61/fovVtRwHJs58+y29/8nmEMPyNv/69PPXF5/jZf/fLaPER2lOCfNBn7tCh+uApCb7ylMMVqnKIvdUndDt86dom/dGAh5ZqV8IoyShMiZWBVAqWZqYZ3HwF0e2ws7nB+evXyfOK/q/8PI0k4fRds9y6fJP24gkwhisXL9Dvb3P+i79Lvv9n0F/Y5+bdm4y2ehTn+1x76dPcPHORE295nLjdooxyVs/d4sb6LYSxNX9DaVCSqqwo9/cJPvDQHcf5+Z/4HYpym1bjMGvr6zz3wgHT8xHSw7mLe/z7f/ULvPNt0yg9gwsSfI452KfZjDl9chk5XGEku1w/2KMUCWc/+ZsML52nPdfl7U/ez+sru6TtNudfe4VO0ub55z7HYGOTj37nnyWOY5xzHLz2aVrH7qY1XEK8r8uFGysoGfHa0xdYXXmNW32NVoKpmXmGhwO31ncwZki5P2BqYb4eo4aKz/3qdfb7a2gd88EPf4TxeEB1El585Rz/9l+8ROeeO1BxA2Pqjvzo3CzeDbChwIzHDPZ3eP38RcaV47WnfpNobJg+0mCkG0zf9STOVFy7eAFvDU99+rd4+3u/mhP3PUwIgf71CwyDYXFhmWa/QfX4FFd3NpABXnzqVUYH2zzzuz9HaRZodnvkL25z45ikHA+phjkLy8t4XxHFEdUg59LFy9iDE4jK8anPn+fWlfPMzU2xvJR+5YuBMi8RTiOReOrN1jqLmpxUFwKYyRfstebpL3yOB598iHa7i3UVQsDpBx/nxWe/jLAVUgRccBTOI/KadW2DIAZi6bjvyCGiJKOsCmSkOfTog4wqxzvf+Q5uXLrMq6+fY+nwEq1Gi/3dAQfjivM3tvi6D70VhCKIMJGbBay39cYzrmpbXPyk66w7U4XEUUvGggBU
QgVcPnuGqYVZ4rjWalprefcHPsDHf+yfEbytT8QCSmswlcFR22dmkcQTUEQkkWe6lRJH9amGRJKlRx7k2uaY977nCe44dYpr569w7rWznDh+jCRu4kzBc69dYH9nl2/8wx+s54vSonXMcFCgdEA496ZXvFAS4ev5qpGiPutAgbSOoiopRiOkTMjzkmYrJiDIvYJgiSdnNWhZm7qYCVzsjat5CUpO3PQcdx2e4+zNDeIowXuLkhFLD9zPucsrVMbz8EP38ugTDxE3UqTLGPVz1i9eYnlpHqcy9vaHPPf8WR59/K76eYWAnMxnlKj/Hi8E6AhEhJYCmaWoSDGuLGVpGA4LZmYbjEZjoijifR/+Wj7+r36EoTVEvpZ/1mc3BAoXiAhUQpDomsFen5tRn/y32E5Z3d6h0WhinQUh6B4/ymo+5tbWDkkj5n0fei+9uQ73Pfookda8+PQzDFZucHyuzboHJyMirVFaorREC1//XQSCq7jjgQdodhq88NlP0S8rdvbHjIuC3/q13+Ebv/H9RDGMBlVtVR1FKCk4OOizcOIOioNdhtujSaL3GONrVrEXlMaSCo11inQybyeiVsMQaGlJ0YgBVbPcZYozOUpKDp28gy986RmmO13e+sR93Hz2k6Qq4IxDRhGRmowifIzHEicZZ597nvEox9laIhYm3IUXn38OpQRRLHns8bsBycqNm7z+2uvcunWNpaV5/srf+Iu0Wynbt/b4mR//aZ7+/FN18TVx/fTB4kKgKj3xxCdCSRAiRnmJVBAJyLKUqqoo8hH5+IAkncbZCpSi01ScmPWs3/IcOzJNL3b85hWIkxSlJVkWESUzeASRjgnzBTfGB+xcNmTjJjqKWR8ucXh1xOWHZrkzXeO7v+c63/D+Mzz8xNt46dI6N6su3hmOvGWZ4aueRx6O+a1ff4bK9Em05ktf+hLNRoNDS7P83Md/jfd91Vvo31pndyyZnZ3Bi4jXX/oib3/oAeamZ7n41M8S7nwnrSQwOz3Dd/6p97D6j38CUwUu3/ScOCQ5d3affjyHSlLElsBVknObU1jnabWfZKsoePnZfaaPzdD1GcE3GdkTjMaKO089xuXndphZWOLPft81/snP5Lzwa09z3z0PcHBlC+ILdE+/m8ahHuf/jy8Q/4W3sXnpCqOijwkD7nrbo4xfu86oLJjZHLHjbzA7PyD3i5S7nhvXV7l6bYwPJUmksOTgOjS6U5zQntJAksww4wp+90JE0mygdITaFPSyDkY1KEaS5uIyl0YDujc3WXXLRK02a/sLLHWaRNMPEe99gVfP3KKzeBczWZ+Xrgjycy/zVWcPGH/PacL2DnO9JtfPXGSj32B33OPmpcucfniJ1wfXeOgTOf35ITf39zl8XDJ7eJqEhG9814BnX9lE+Hl+5dc+RxRLdIDlY0d45fVXeewDD7KzXXD+puOuEy2UD7xwEXSjiRTgS+i6aUIQdLK30pclX752i9npRaZWEqzLGJhjiP1d5o89wfjGBS5Pvw3nHZeeWiXKDrh5tMPJX7qM++6TpOoYLa348mdf5sSdD/AzP/Y7LN3zTu59eIqt9zSZ+twWewsF4eomRWiSNFJsVfL8U5/iyOGjhINN/uhDOf/bl76M8xnveu9jTHUEv/Spr3Ax4LynrGxNLg5+YrmqCEKitEZFEZEUWGe4tXGLc+cv88g7H2Fne0B3KqKRxHz39/8l/t9/8c9RbtxCyJqpWhiHrQoiFfA2IkSCSCuOLRymyHMGgwGiEdOIe8zNzvPZzz3L1dde4dS99/PlFy9w9M6HOPPM0zQ7PY4dOYyxAqlrV/UQXI1cTPyz3cTe1jpPcDWbXgqBReBlLUNTsUIlCXlZcunCeR6bm2J19RZLh6aJlSJuZDSm5zjY3yG4HJTHl/X4w9ucSGtKBJGuzX9kkDx61ym0UBRlgXKOVrPFI4+c4uRd9/HJX/skg1sbzB06xHNnr3L42B08/4Uv0uh2mG63ee3VS7z1rQ9P
5CkTo6DK1HM162qJZJhAxCLCC0HQ9T2krZSNGzdJWylJ3CCKI4qiREnBo0+8lS8+/QWCCDhfTiSYFibnogfvSWP1pmGQEoJUezpZQpZ2MK5A+ASvJMdPnaCSiicef5iDvT6ba2u89OIZWlmD+x9+gOvXbiDEiP2iZHgw5MGH78L4mhypBSAUQsdIHSG0JEnqk/iMrZnGe4MKFWtkLrh5/RaHDt+Lc444iVBSMX3sGOXGNm5/B0QxOa+hPlzJCvmmTawNnkgotHyDI2k5Pj+LVgrrSgga50uOnzzKlade5dv+6Ddw/wN30+h0+NxTX2JnfYuZbodGb5ZGMyZtJggzQklH8AVREMQC5MRauhSezb0h00lay8iKClvVRM8f/Wc/wse+9UOkKuKLT73II4/cy5VLl3j5xZfIi33+zPd+O4fnp9lau8WP/OMf5dqVq7XKIFJkKiGMx3XxZzzWGVIlJ66OMUEKMhlQwHjcx/uKZrONtxVOCTpNjSv7fOs3f4yp6Q5Xv/QLtSRKBPAGQoyX7k0UorU4j5SG9sIh+hfP1oXAGwZddszWzhhT7HDu9XOEUDE71+ahR+/gg4ffQqQUtvQUY8PC0RmO3XUXV69eZm+nT8d5Kg9TOqbVyBBY9kYFe4Ocylq0lFTOECqBSyQnTp5gPNhn49p5lFbEGqSuEaJOd55m6tHK0W51+dx5i+weR+gUkgSvI3JXW83axQWqwjGDg1AxtAmuqAhba0x/vmLqw5LXLl+lO3WMb/3INL/26TOcNyMW3vMI1o0pnv8M9754gue3XuTxR7q8eskx3qn4lV/7bZRSdLuHaHdWyVunqJaP8OrZ19HlObp3PMmtrZSLN9YZ2xM89ORHOX/lGu94+wnuvusko3xIM2ogUoFwN2k3U1ZHS+j2MkQZQkdYJAcjxfzCYQ7SQ0QEZpzHBsv2KEDVx2/eoLn6AuPeLOsvPIP7hg+QtC3bA809Dy0yvPo6d55axiPYP/ssV/ehFZ1n99r9NGceJCr6VOM11m5c58yXv8QTH/geLpQ3eOG3fo73vO8xXt8/ysXfeB7ncqwvasWMLSnCiPuPTZOWFe3mLLoqabenuFwWJO0lyJqIOIUoYVyWVM7RPnwXY+Noe8d1uU/lethhRTEq2Lj5AsiC1d1tFvZXuCmPI5JVXnnxOo+/823sSc0rX/gSx5Y6bCRd5u8LvPBza6wVPfTxt3N5p2A4qhh/6iaXPnacxvLdvHLtVRqXz/N9H32c//AbZ7nrwSe4/PJNrl69BaIC4OLVNRLtcbmj3WiB3SVNBb/5cknUOw5xjFApg/UNtMlod6fop4s4b1lYuIfgDHuFwNkKtb1JdO2LTB0+xJkL11ha2Ec3Glw5c57/5W89yKde7bKt9vA3zzCzP+D8wPHWOx2/9vRLvPPRO8mOvYvz25bx6DrZLwyRf/WDrC/e4uLTz4O0dBce5sAsEN/apNUQrG9fp33sbnYO1vm5//jJyZHNX+FioDQeL0EohdYJUgriWJLF9fx3+aEnuf7yl9jf3WZnb8AvfPyn+K7v+zZ2NreI45TXnv8y1g35+//gbzEe5vy1P/f9BF2gk4A3deKcaTW4Y+kQ1zY3GRUDoq3L6LxD99SDxJ0ZStnm1vXXac8v8/Jrl9gejzhxd8Ts9AI6Uuxv7/LquZs88dDx2sWu9jchSImnPhzJBzc5AyCujRtiXZ/GJwRHT53m5vlXyAd9rm/s8gsf/0+882veRbVmiHSCFI4rFy/w9g+9D12MWVtZ4dO/8TugQEWSYC3OVpTe4oNEqxq+bbiKvd2baO1pdZrM9E7y5edfoDCa7Y1tmo0Om3tDLtxcRzZm8EFyeH4RqSJef+0Cb3nLQ7zhh1BV1cRhEJyICSKgRESIImQUE6t6Qx8VfW5trRMnbbR2nHn9LA8/cjdSCOIk5ru+///Fi9/1Gm44APKazeoDo7KiRKGExVpFEtejG2cFkU44cmiJoszJ8xgfgQyG43ee5qAo
McQ8//wZtq9f5s57H2RnNOYXf/tpyuGI6UwxqBxHDi9w5tXLqLiBkhKlII40sa7HHFGrxWPvfC+/+8v/kVFesbs/wDjBs1/8Ik8++Sh33nUC72v5TP9gj1deeoGH3vooH5jpsXFjjU/8518h3x/jgyXgap//YBk7iGxtsBRCjIsUmVLMTvc42Buw394ibWT4yhKnMR/44FdzMBwwGFV89tO/hZZw5fUr5IeXmFla5pkXrzDa2iRyQ2QSEYJDhhjrIBJ6wvCX9Pc3qGxN+vQT62RrHeN8i8oIdlZvsrW1xjNf3KI70+SrP/gWpnptgg0Uec5ya47H3vEOhFbsbG1Q5IZhFZhr1ja6bkKElSrUjpuTMchwNGZu6RA7K68RN3tU07MkSY0+bG4d8MEPfw1rWxtkrR4q1OTU4F1tpTwpmiprubi6y9e+7w8jVcpcr8nKhdff5GMIGdNKNX//b/51Di/P0Wq1+NBH3sPxk0dqoyzjWF/b5datPYaDAadOneI7vvvb+OZv/xBXz1/ndz/9FOeePcNi3GSxk7Ew1SZtpFzZXueVy9fZHxkqb5AiAiVRomQuCbjBCl5p+rZP0u7w0CNPsJ2n3NwcQLbEuSsjhEqJpyuiZqc+udTVBXMIk7NPpERJjSkKOq2YQW6JZpaJkw3EzByxaPF93/qr+KkFdvZ3qYxkAUvc7GF4hHh/gzM7gqm0z3QjZ7QT1d4JwRMpTatVcvPmDca7K8TlLkwtsH75RdotRXv2GIPze8ydepCgtojjBX7lt19jpjnm+maGENCeOcLlVYUTkrlGUjc1wiGaCUmWUpYlKgTQCiEFdlzQareobEbozlNNfQ2j0ZDjiz36ssk3f+NZHr43Iko8z1/fYKo7w1Afody7RTS+zh3HlllZf43+xgrrF5/n+ANPcuVLv8zpB97H5QvnyDfOkMUNzp5/Dh/W0UkXMxpMyLoBr2MiFM1OyUxY44WNw4iQcG3zgCzpETdymlNTNccMR4wgRDHGGpQIiEhhXJu4IRFG0uzOksw9SnAJS1P3MtXNGSnPyqCDUjfIOn+IoMekR++n73IOrr2EVDexYoeVC8+yt7tLp9MiHo+pqto2e3PrDK3ROqI9w3y0wg3T5t4H386VV36R4H+PBO9t4IF7ZtkfKnZGmpnFO7myAj4oOocFUVTbxfuyQEeS4E1tdKclCEneH9Bqd8hHDtmaITz2TZQmZ/50l8L1UYOSrLnFL/zMs8w+fD/ZVI/rao7+9hDpz/Hq+gpCpezaWS7+5o9z9MG3kl55lkQ+xKsXL9Bffxl1sE/v5H3cuPAlThw6RNqe5Tvfcxc/9puG/fXdejxORX9UfeWLARHVUhAJRLoCX0N6MkisdaxcPodIWzh3C+8c5XiL69euc+7MGU7fdZSHH7sTJRRXLl6n25vmh3/6R3nh6Zf45V/5DUYHFYcZ0ZJNZqdmeeSu06zurFEpCboiJE10Ns3G9gG5gf3dLYypWF5cwo0rxnnJdGMa4piLF27y5GMnUEpM2MIS60BJgVP1Jh3HChUcWgWUMCSRpyg9auYwpX2eYV6rJi6ceZk4jjh2fI5Lly6yv3uDuflpPvYtX0dRvtE5VayvrLOytod2YFyORhBHamJ2I0m0pStGiIPrlOIwi3fcxUOdZc6++ho7mzvsNSyxitBaMx4MWJhfZLC1y8reHo1mxubmHktLbaIopfSgdYKSimSiYKjnxp773vEOrrz8DFvrW+we5Azyin/1L36Iv/TX/wJ3nDqOkhHGVLzw7DPcuH6V7/sL34UW8PoLr/Offubj9elhWYT2AeV07duAw7uAcQoqS3cuZn/jEr6apd3q0Znp8uxzL1D5wGDg2dneIZ2a59L1NS7f2oIkphGl7PWHdKdnUEhef/0SWlt8gEYsCaasSajWU5YxFy9fp3Ke0hps8BgLf/9v/iC/8ZlPoDrwqd/5LJEyVGbEvfef5L3vegjvApWTXL++wvmz5yjLEZ28Hp3gLM6L
2qlN1u5dXtRnCDSaEQzX2Fs5IG11iLMWj556kn0jOHnHaT7/zPOsXb5OubnDfY8/zrWNLZ77wnN0ez3UxJbYh9png8kpmV7AqKiIZo6xePo+0laXtavXqILDOIeQgmbq+D/+/t8jThTzi3O88x0fpDfVRAjB3vaQG9fX2NnZ55577uZP/YXv4duqb+fSmbP8ysc/wfaNFZaSlGYj4dbuJpsHewQBPkgq71EyopFkyMQTywppdnG7A4ooxWcdHnvsSWRziioEZo7dUfMoJux6L5jwTxx5pTFVRZw0SdoLTM21sKH2VDh172nSNK7HLlnAGEOWpaytbhGCoNVqEQi89MLLXL1yjVG/T1XllP4IZZFz9vwq/VIxc+Qws16xtrHLVn/EWx66m7tOnOTE4hLnVla5sbVFZX2t5lGCNNb1sdAyEPyQkGuGw8BTzz5P0jpOAEZDQ6vbrc1dyjEBiMM+IltEBkekMqpQNwaRDDgzJtEeOT2NNSuQdFg42uHw4W/i53/hJzn+6B9CHZyn1UzwMkIkc5RcR578emwU0bXXuXHzdxGiNiwqKk+vM8X0TBe3MEumBTZAPi7YuPk69x7zfPr5A3R1AFiurRywujbg1vY54qMPARKfB5JWGykV1Xhcj1OkILgEbwriRlSPNb3HTQifKEeUxYzMFNYNiLpt9PEZproJd737T3FUb/M7T3+S+598iC89/ymeuPtJdHeehrqXjXOCwyfaVCJiYXGeQf+AOx//AJVa4PD8DObQPEnwDM0+/vVPsBYtU7CHFA2EmqCgMkbM3c/O6GWiToYMcLB9QK/bqs9QyUe1kVjlsGkD7wOxjPChPqlSUMtttRQ0p+eocoFKm8w1uqwVG0gFjWaHP/zhJxl0miSDDZLOMkLNMj09xxd/6Z/wxNd8jM2D69z78HvqUyIPWpx4v+FqK2NmfpFInkQqwc9//hdJpu/n/e/v8uv/oY1hgEBDsDhgczuld8QRdbqIUJ9PopIMV1WIEHBUuHKLEAxSmPqETlcfFxxFCm9MrVKbOkSRH5B0Z9CtOVSV48U6H/3m7yPWJa8VGY3NdeLuEaJphfbHWH3l3/DV3/QtYHIWG4sUwfGWJ9/BlVHM0vws8wtfDc4R6ZjWzBE2L7/ItOlz5WafzpH3MPrtT2BdVVvbv2nq9RUsBkbDfUgiYi3rWR/Uenenub464PG3PsLR5cP86r/8h7SmZ5hrNfn3P/ZvOHriKMeOz7G7M8DYwC/8/H/mzjtP01KWYruPEJqXr1zkvR/7KOe+9ApiNODm3jZRFNOMY4ZHT3Dr5ha3dkt2tzeZXTrFXnmdlhxgRmMGLtBcOsH83Ay3dm6R50Oee3mVdjul06nnKlkkca7CYVDOIUzdhepGXc27IIinZxjsbCEbHWR/NHEWtPz8z/4yOzurvPXJx3niLQ8RguD1c1e4cuk6jz7xMP/LD/4ATz/1LL/yyc/T3+lzxOcc7vboTbVZ2VxjWBX4YNCtDISkc+Q4P/kffo6Hn3gX995zL2VuuXrjJt2o4o6jR7h49QqtVhdlc+LOHN1eyuvnrnH40GOgBHEjIhIegqWZGqzzKKGwwXP5wiXyIlC6N9jajl//5Z/jr/zg9+Ndn9/9/PMkacWRI0s88fgHKHLLaFQws3iU3OVcv3KV61du4b3Aj4dECrRWSFlzMJRKiJNAQxXEo3Ws65MefTcP3n2MOOvxhc98lv2dPbqz89y4sUGcJQipOLK8zGjjGljHOC/Jq4C2ORJP8LVBTZAxhamYXmwjJYi4hR/U8/IQPHvbu3zhc1/k4oUzzC/O8t6PvOv/bO/NYuxM7/y8532/9exLndo3FotFFrdu9qJWt1otqSW1ZiSPWj0Ye2yPZRu2Bw4SxA6QXCRAfOfcJkhykcAI7BGcKFLGljwaj0ay1Kt6I5ts7kUWyWLt2zl19u3bvzcXH0dJkAygAApgRPVcFVBXB1U47/97v9//95DJpAkDj53NKvsHR0xOTvNP
/ut/SrdV5+bVW/zx93+EZaVpPXzA4lSZQjZHtdOl4zhEgGVbyU67nXSIR36dVD7D7t4RW/s1fF+n13IIgpjCxBRrWzs83K+hNEmlMs1g2wDffaLMTkJ6uoQwjhgMPeySQuYqTJ44Q6h+gFKSucUzZAtZNB3cYRPdLJNJW6w/3mS0Mko6Y1E7avDu229zdFSn2TxiZneCVrPJo/vbCJFmIVUkqxvMzUxzaWmBVrfFxu4WTd8lihWxTNrpDGmgmXqSlNcFQ9dnbGGGIIDbH1+lNDJGHOfQdf3JTZp4sh8dPVm1C8mmM7T218mNODzaGZColaE4OkpK0zEMA9O2uPj801x64RKePyQKwE6nKJQLXBQRaw8eMBwM2Nrc5PrtO4SBwOl7zJgR81NjjNsG4akpVje2ieKAo6pHuzegbJcozSf5kGYY4FlG8kQZKnwpCGIdkZ/DLI8wPjPJ3u0DNE3iDyIiESF1Db/fRGqSYdRGGA6mpRHpGmH4JIIbx/gyBmIib8DRiRDv4B6+AfZSiROLSxRzNlLpGN11eqEBccDRc1ni1gGxEDxeWSFSPrFygZih0+bEiYtUh4+RVoah4xFJnSgK0I02O4e7eG6TsegyAoPqxse4jsHcVIaD1hZCmugCol6QPKlqEl3ooMOg4+DWHjMMs0ToqEigiJKjtAcyBj32iJ0WQSDpGhHl3Rjd8NjLQ6cXMzuVIZsv49YfI/Q0AwSlazU642MEUieMhiA83HCIr/eI9S6BNyA0AjQkwcwEwWMnMUyS3Lwl+zMRzXoPUi1CfwdNMxFRQNA5xNRNAqcFuoamhTh9iVI6esrCj8UvGy4jQjQESiUV2l6YWA7jwCdSu6QyIfPTE/xs7T7DSxnC/TXilESXOlGoUdIU129/yonRCQLho2122bloEfTbGGYF129TsrPc3fcgX2P74BB0gzhKQuXJE2RIvR+SczrEykk8D7Ek5ehgGvgk/o7eXo1yeR2ZtZEkN4E8WYf1VdLcqYkIGfm4vaQBN/Q8bOWyrQvOnJgm2FujN6/wqw+QKRMZSZx+A7pHxIaH0+ogDZNm1adxCnQ3JFIBQkaEgU0YdJE5l9A95M8+OGLh7Cmi8C/Kwvhlkd6vdRhAKDRJIgEiEfVIDTw/acNLWzaeGxChOH3+KVKWxM7YnD+7QKGQJl/Os7W5TWVigp2DOpXDGpVSljdOneCvfvUV7NlpPvz0No3GFp85N03bDTEWX2R7Y48vvPwc3VabVr3Jleuf8upXv871n/8bphaXyRSKVLce0zEktcMasxMj9IYORtpgf22bIFRopkH3cJs4HBJLmXQlICCKUFLS7blMLC1SWbiACnw6B3VmTi6SzeW4c+MjTpxcxPeGbG0eMDo+ylGjzntvv02z08AdBDx8uM6dlbvEXsjoxDQHbhdDM7BlASudwXE84rEZPD/m4XaLS089T/Vgjevv7zF9+iKZXBe3sYvv+aQKZSbmlnCbVWrNOputA/LpNG99dJ9cLkvstgg1kEInDC0UydQa6GkKqRTp8hK12i5hGBKSNIl99zvfpdGs87nPPc/Tly6g6zo7O3Xu3nnAYDDgy6+9xB/8w7/PB29/QvD+J3RbA6KDHZ6ZKTJZqbDfqtF2HEKSVjbT1pFCEGXySDPN7buPWTx9nqWz56gfttg+aBCrmGcvnOX2yipSaDTaQxaWpnCCiGw2jaMUoVLEYVIYFPkh6zs1Tr32++TyU8TOgEa9im2nKJSzzGdz/PzHf0KpUmFmaortjSq5XI4w9Hj33Y/YWN/g3PISe4czHOwdsra2g8pX4KjGxcmTzI0WkYbJ2blTDD2Har9Bb9gnkBpoOgpBoFIEKk+mUsattsmNTOI9ekSv3iIqGxzUjshnsww9l4nxUXYsM0mcikTRjJZUKYdBMogFvgO9QxqbMbHvEUeC0ugYGdvAMA1S6RSXXnqWp567kEiWsChViuRGilQPd/n4/Su0u222P6zSavfZ3NylOOxSiAIO
2x0u+j65dI5USmdyZIEZGTJ0PbbTGZbPL/Ppx5+ydHqGei+gUCjR6MNmdUgmbmNlcqSzFbbXNpFajKaSWwFNJeuLYfhk2zbW2Ft/ROZgk14v2SjxQoXneUgtSsyWMuLejZt0Gw0m5yY489R5zLTJg9VVNh9vUhkfI1cq0+q47Ff7bD5ao9844q+/+DSREDiahqFLluYnuXF/i5WtGk4YkzUlUkJGt0kX0pTPP8OF5RN8/P4HTGZs1vb6vPDZF5k4scDq2iF27x3szAStKIveaidFZEBQ0PC8HrkiOEqiKQ1lPrFLiifiMqEIQ4e9V2PGVjX2T3Z53NihuHQSRzV5zUtxb1wRBR5jnyrqz/lEw0PCyGXQrRHGfWJ6KBx0ZSM6O5jWAn7gEoYDpOFhjoySUpKVvUMsO8uRE6OHXYrZNqWpU5y/dJqffOdfoWuCXjhOlBqjGA1AaaTHc7TcIcKI0dJ9hpEDhgaGlpjrAPRkFTOOXKK0S+yHKKlj/dig+dUuxYzgG+UZqoUGr7zyGTaNIVJ5xE2PzhcdBnFI7Jm01h8wsXQRrbtFplgmiE1iv0ek9dCzWZRh4rkeiardRcU+iphQCnrtffIpA7t3l6mJaVoNE4jQ48QKGeYFXRGSTVsERooQibI0tER8gJJJf6AUAZEIiYMwKbCKGmCM0449MosncPbuET1jcuJnkuoXPaKoxxuLsxy4j3j9/DJr+w/Ruw6LzRLBfAfhFVBCIwxbzI7ENALQs4o7azVcv43AJ4qHKCKQIUPXR7QjLG1IL5rElwUKKgAEqWyKth3gKA8hG7iRgTB1YkMms4RMfBdCKMLYQ/CkYMurE9sGF0/Nsr7xmMfNJoyPsjcqmXw/pvWFAa7X440zp6hF++QOBjQDh8mx0/SvDsm+7OAPdMJwgBf3MWyJtDJkJlI8Lc+zslnjxz+4hx+2UCQle+JXPOb/Xw0DtimRerLvK0hUnHH8RD4h4dHVN/GCiCgM6DRrxNk0Kgr48M2fUz1oMDU1jZ/Nsrtf5+6Nuyxmspw+NYsxGFKJI3ZurbI8JZktzeH4IFNTbB11+PyrX+HBnRWaRy2sTIFyRXL72lVml86x+uA+2bE5bAMKuTTFQh7PjzCESzpVJjs3Qb/vILSIuB0mASypELqOCH2ShsLEyOd0O8TDDrX9XdzY5+TEBIahY6ctpsfylIopKmM5BoMunucwd/IUW9t1Vm/fZUZI/v7zzzA6U+HWrcfYqRwf3VthdjKFF8CjnR6dvQYTCxfo96q8+NmXONrbxM6XuX79U158+Qvcv1LDyhUZ1rc42N1icHTAMPRZmBmj3+9TmhghV7CT+mdFcrUW+khNY9h3GOaKFE48xcT0FPc+fp8IwZmz50jZJhuP7jE9P0sQOGys71GpVEDXub9yl0F/QLfbotN12dmrce3WbVq1Gt9+5ikqEyepH9WZHZtjKk5KkhwV4KdyBEqjPHOK9z78hK+//nt8cu0KQS/g9MXnuHn7JmUtJGvqNLt9Lo3NURy6jIxPs7Gzw2DYQUVR0nKnEruf6yermSlNJ5vJsNZuoBAsLC9TzOXRDUE6V+TFL77I7PwUkVLYGYt8eYwLly5w69ZtWr0uGx/c4OCgzr17D9jZ2+Mb58/hCsHHa3uUczaj/SJhrJCxRkrk8GLILpzEHw4p2ikOvQxmtshvf/N17qzsoKdSiWuieojnhTz7zCX2Dqq4/R5hHGBqEqHC5ClAJaFIFUZJ4DNUbN69msiuvCGe5+F6A3SSbQxByKcffsTG6gMmpsZYOLPM+voau9tbtFsDzGyB7d0WN2/cpVqtEjkOE5bkay88x8SJSXZ2angHLWJThyjCUgFd16dqp8mOLzK0yvSxwU7x9Odf4c9/co1Xf+vzDIXOBx/cZOzEeR699afYT/odFKBkiD+MiaVASg3DNhn02rg9SRSHaJpC1xWtWhWVL5DKWCCSV2brjx5x9/pN3vvTn7J06RK+1KjXm+wfNNja3iVW
sLW1TTzoUSpm2e85rG00mShlmK1kKJUsRisVtMMOba/NTDFFHA1Z3z7icDMm3RgSRBBbRU49vYw+5nFr5R4/+PN3eOa5VzhyfEZ0hR04GLaJHiVVyoWZEbb2B6QnbdrKI5YKJWMQIbqmJbchIoZuk1B4uCJGbfQIzzVImaeRhs2N/iG+KFKspdCnXfJlgTaRQSqbYfQTFDGalKgoYhgM6M1fwHZ8jNaQlGFiYmH1A/p9MGKXQvEsj3sCP1hjbOZFbq7sYNq7tJ2kb8E0JcMYtDCxEequh8orNEuQnkjTUy5KRCgBpgyINR0pk2C0HvXxBi2GfpMBAXxtmvJDE+8zLv0TIX7Zpe80GRmZJt1MYdR0+l8y0WUqqcOen0OZIW44jtnzSAmHjO9hmBa6F4OW4UP/A2JconiQ1FUrhYx1Gs0eE2Wd1jBFzouxlYdh5NCHfXSZYnJumvudffRy0t431BRCRESEyfeBpoNUuO0WlhkSBH0avT2mxsYJrBBLKS4/+ojMSAVp9Tj66pCp6ymCiwbj4ybd0zmM1Trzxjw5YbH20j4FTZL2QpRbp6LrtGpt5qbmiH3Jtbd+gYo9IuUD3pNaekEQDznq2GgGZFMCLw4widGkhnBd9AmbIPKwp3L0wyHoYTJMi+R8kZpAEwrl9PDdHn7g4Dj7FEaXuTdooGeGGLKPITPoIiZ6XlF5KPGfN8h2LTpLFguDBsHUIpVHkvXf61OKJJbuYh45iSdCBehKw2kOeHi0hVN5gU77BmHcRZHs2P5FQdmvdRgQWtK4BCSVsUAUh8QoDEPgu0M04OnlEYToEwU9hGPjOybXP7nBFfcTzj13ie7A5Ztf+xKx41DtdGl0u1wQp7j4/Bk2rjus7FWpdXxOvzjPxdkzXL1ymZnps2zVh0TNHp978UXe/Mm/4dDLMjU9x9T4BLt3rtLrdNGzJYJBi/t3b7Fy5xamlSWTyWCmwQyH6DJZ/5IyOXhAEAQBQki6R1Xu1jZpHbUIfI92s04uYyOiNHeufcL9W7ep7lU5/ewlHm7VeLT2mOphg9GRAqeXFtHR0N2Y55bmeO/uHZaXK5TTirZv8rgTUMhMMjF9gtJIkffffod0YZRMykQ3De7eusn8yTPcuHkLadpUimVKk2O0BgPcgUOuUCJlKtJp60lSXiG1JDgWhmFiVQtj9GiA2zx4ElQTVMbHsDSJnUuzdHKak4szFCtlDDPF/Xv3mZmdoVrv8ODxIfdW7sOwzzSKf/QP/zYPPrnBWr1GNgy5vrqLIoBY0Pd8gnGbz37h87z/4V2WFk4SDpq47TauD1euXePsxWe5886PWDpzlpHpKQrlMv3bHfrdDOvrj1icHseUSXOelBBGiUzFMC02Vz4ls7HKwd4uQRAThBGO62Gj4TsDrn3wHgebs0zNjaPmF7jy8WUOD2pMzszyeLPKnVv3qR4eYhoWX3v1Bf7e3/h93vzXP6JVj4iJuL97iJ32yVqKo4bH/ZqHfdBmcmqCc+cmydkFvvv9H/B3//AfE4QO+ew4Jy+kuPNolYlMRNE2ubK7y/yZi4gnQcE4jpEieQLyo5iIROBjRyHDTgvPB9sI0LOCwdE+cTqNnc2gIhtCk7rn8rMf/Tnzc3Ocfemz7B/UabT7fPrpdRrNFrYmeWl5mbN5m6JtMj07ydbA4VatAxYot0clK1mcstiuhtxc2ee9dz9gZ2+fh/dLFEZGaQwVzcYQeXeF3SOP3//23+GTq/cYGRul9TBAaqAbRiLnET5K00nZGTJCIqWGrkOKiJOVPGEY40cDfH+IFxgEVpo4lcE0LT56/wpOd8jqjQfkz19ga3sXKQWNRp1GvU2+kOar3/oyf/DXvsHERJGPf/4J7/5vH5CWIdnSNDPzZU7Udjk9XyCfVYRRFl/oPL7XgHabf/un/w7TkuRzY/jC4vqtDWYXlmi3W2Rlm3x6GoVLOlskru4gJEz3LHajJn3HZGqkRLOxRyRChqGHmcuiZ0sITRDbCl1p
HCxUGdmOeVHO88lRiJ0PqXdg/o5CTrg4JyNM3UDoETISmHpS9PXSV17i4599hFIBflpHswY4hkdoCyzNwlUBsmzjDByCvU8h0tHdHi8tPc3Kh2uo9RRprUMswcpM0BgMEH4LqQkiT7B8+iQfbAWEXo1K2qTbaeCikJaLSI1iGBkQIWHso6SHJnooPcKYCDj/+QEr/zLLXtjF9HsEoc3MkUQLQ3q/DRKDgpnUaufSacIwwg0NsCTDqEOYTjIhppKkYoUfrz85DEAz9OQWAsFg0KTMBYayQyY9Rs6SuFGfnNdCiA6d+030aR2/36cwMoLo1vFVhKNcrGwKKYvomkFkRSACAr+NihsEFNG1IYbos9+sY2aSamPLtml/ycfey3N3WzEmU9QODMY+axC9DLl2OjnAY4XruAz0FJFpYiqBHkju/PBmshqP/D9yM0okgcBII2UYZOwZPM9D4CPdZE3+M+PnuUUPZ7DL5EiZo6M9AmJ8eqSyI1jZMkLEBNJHaQGh30AYHlY6wou6BKLH0tIUq1s9DN3En4kxJvNkHsDOXsyIZrDdLzHuagRv6KQ6KUDDiXr4hkc6rRMLA5MIYyRNr6t4fP1t+tE9iDUQ6cTRI/4/eE3wuDZkeTyNIEBXAj1tMDVTYnZ+FKlLjOwIbrdHNmrjDgNcxyeMwPM9tNinFSiuXb6KZeoMalUqJ89Sa9WYy0Oj3sKJQJpjFE6fwI5C9GyesqWIg5gPrl7h6c++Sj43woc//S7FbI7lC5fYbPTIl8eozC3iCYVtpXnw4B4LszYnpiqUSxXu3F3l8OEWo9Pz6E+0SlIJ0CVCJp6EWJNoRAxdh+kxSWaySCA7qKhDPEwz9AVRpPOL965yWG0jTZug2+Mf/Ud/hzde/zL/w3/3Ha7dWGFs0+T0TJmFxQUe7G6y3VB0nQF3H+7wW7/7RSamp0jZOWSuwfpBFRlFfOt33uB/+aP/CTPyGRsdZWb+VJIwr+4wkCaV8TGCWLBy41NuXr+FkAJDkLSzSYHnR79UcT68+R6uE+D7LmEc4vQ6kLKQKuLqBx9w59ObOL02C2cv0HEDdvfqfPTxJ7TbPc7PTvI7l55hciyPoUImnjvPw41NDgYN8nlBqZyi1oVbdw+p7teZPvMSg1hx6ZmnqG1vUsmOsN3uMBi0OdjdY3L+FAEGW4/XaJyrMmw14MQ05VIFQ0ikLpMnijg58AHiKKS6u41U4LoDiAN69QPIF1BRFoEiDPs43T4//J//mKXl86Qnp9jfP+Dh+iY761vEMXzrK6/wlRc/g+nUqX3wAefHKvRaHRrdQy6eKZNK5QgCyJTg8vYqqpfm3r2HnD53iUcbuzx14RK/+MV7eMMerf06ixeeZW5hmZ3rH9Bvd8mUKqQLZYTQk+4IoYNMglyxijCzNumKjW5YhHGFwBtwctxEBS5RpOF4ffzQQTk2A9dAGjaPHj5m4/EmrYFLYxCwXzvEMFK8/NKzLKVTVFwXJSCTtvCHDqfzGQ4Ladbquzx/dpS8pVAqZk7L8pRnsPJ4jVgJbGuKrc0tvvyVbyBFlU9uPcAwily+fI1et0khX6ZJIh3TNEmhaHLumTkK5TSWZdDpBASiQL0ZMRLtk9KGRCSyIs+NCZXCC30GQxenA51uF98LOGgcMe72KVuC9a09NMvie9//ZxSLioUTM7S7bf7kBx/y5psfEQc9co5Ord7iD7/99zhE4O58QBgrglCjUB5BRvsImWQalpef5tbKJkZ6hFMnFzl36TPcuv4Jb7z+H/OdP7+CQDL0D1k0En1rd3dAKR+hRJNOt8XpUo7HtRp+7CF8HxkqJBoi7qBiiMIhzimD/tke+3eqjDTKOG5E65kuRtZECyUxEoViQgrOnygx/9oXqIyP0B+0ULt1kB7CElgpCykhwCdEEYdg5g3iMLn+10o+/3b1fdInDOx8mi+d+QPubza5v3sIQMnoorTknfpou0PKDlBxDaEyFLQBkTdEaCF6ZKCFXtJ/EriI
KCT0oZyvIKMh2x2f1rNDvDtD3B2wdJ/hCx6ybKLCpCsmpWvEgKHpya1vx0FLJ42SlkyU4xEBvej/mlCffuok258+SP6LNEjPjfP1U1/mj9+8SqgMpOcxZ/QSxW+k85nZZe4crtPv9VgoFdg4OMBXDsLMoesKQg0ZDpJQceCSSeUxNR0R99FkwO52k4XFEug2gfSxDAN/0eFhbwf7TJGHcpvxUycwDZ2UpRHGEZqpk03bQEwsfDwhcJ3ERZ68logBxbOvvcC1n19GKo3XfutvYsgUf/LxKkrBlGiSMT1Q0Lh1CzPuoYkOvV6X6YzOfquO1Hz0SEeGyU1b7A8RsUvguZRGKgh/iIgC9tb3GS2WEHFEFEkMoeOnA8LnBOvUyC9Ps1c4onRhHFsXmDoITSVB84xC0yJcAUEMhAJzuUD1409QhAgZEAmfxYunWbv58Nc/DNzeqBL6Y/zuV+eYnS0yPp0nnzXQDPVLnauhj6DJCiKUDHse/Y5H66jPb78asN90Oaj32Nhps12tokyDUk4n9jXajQZDY5SMBaeWnqHaaOM7ip+9fRs9lSOV0+l1B1iaxszCIgfVI6Rh0jiqMnRcup02cehQHp9mrJSjUrSIvT6GyvDyCxd562PFh9sOT8URRigQhkLTBLZtsrw4TXmijNBNhuEMpbgGXh8vVMR+RKQSh3h3EOIHIffurGCYOsVcnqv/7i22NnZ5852PSUd9Jhenud9qMTn1NJe+tMwPv/8D7j7eIo4kpZEcyhvy7i+uUJhYYGRumUo2w43LH1IpZJiZmeGg5VAYm6HeapEuj2GlMnjOkN39HSbySZlNbIjEwQ6EkZNoiAGTmGGjRSwUE6Mm4xWLcHDAYKhDOoe0LQJnyE9+/C7T1x5Qmp1lc2cXpQT/yT/+2zx69wpbvsutG/uMFfKM5iymxktkxvNsbK8jfcVOy+HR7iEXnn+F7eoRf+X136He6/HR5dtUxkYpFgu4bsDB7mMW5md49923GJuYwh0MmBsfpZyy8Hp9mq6gQtK7H6oouemwEsWrFAqpSU4v5LE1CxULgqBLFA6IBxaeNHEwWH2wxu7OAc+/8goP1tZpdnu89qXP8s3Xv8q8XeCn3/shaQIqlVHsTIoXzy5Q7evEYY+eGzHwU9x6uEm71eLMhadp912uXnuArhv84T/4B/zsnXe4+ckjJqfnuXb9MnMnljhqNxjGklqtRrN+hOcOSIkITRNkizrj4wUWT02SyltJBbVuMfA0jnaOKMg2jmsRR4owNIkCSaig7w7wfIcwTOqyH9y4ycz0NM8tzDAM4T/8D77N43c+YPXGKiXL5rA9QAwCpoKIZ2amKBc0Wp06ngWxZtP3DfrdGq2jI0bHJ9GlQbvVp1fv4XqSSmmUL772Ld77xc/xui26nQFffeEklZE0xbJBJpdG05LPJIDZ6RSamdg91XCGTmNAp9Gh3xkmIbcYUhFYRuJX99zEgnnUavHWz35OKZ9ncXqaE2dOUDAVI8UK3/tff8qf/dn7DLodouGAwO/TLBmkozw3b99jrd4nF2XpD1w6Q1hdP6DreLR39/jP/ot/wttXHtFp1qHW5vf+xt+FcEC5aHNmbhoVJ6G7rO4iwoBIJrWtIw50jKRwbMft8sxTF7nx4B6aYaEREvn+E3tjRD6fx8rorOxEeFM9NrwuZjlNnI3BVuhmUjKlZEz+cMCJxUlGTo+hCcXnfvdzvNDo8qdGgC8EURQRPSlpMjUDPwwo2QUyeYWl+zhDjVbcwJ7Oo0optu9XSVkpVBwgVUQ2A42ug2VlqW3vsTQ1Slcohs6AufExjF6X5qCVyNpDFxXHGJpAmDqilMNIm0RaxHZNR9gS63Saw70m2ayBmJBIK/kcghhZFMgn7alEyfefEonESqgkJ2JoGlokEgV0GCEQzCxOsHN7DYIY0xBEMykmR0ooldwWyDhE18HxYzRDEG7W0AqCmJidXovnP/Mc1+7eQtMtNCKiwMc2BVEYkc2nEUYG
TIHUAoI4BttiEDqkUsnrQ80AzRSkKzZN/4jyfAmzpKMLgan0pDBPU4RRSBQFBG5AyraIlc9COc1mZ4iKQLd1SieLjC2OUV+rs7O9z2uvfJ7w/TvoWkw5n2LQ7iK1DMoPOD869kQRDZ5tsHBihq39PTQpIXRBKXQDEDqZUhYza6JEhG5KfE2wdVClPF5GagLdlhiGQEnILRTpGV1EXqEVkr+N8CPCOHzS/xEy8IZoQgASQxpk8pK0KWlHSe4nVTI5/+WzNBtHdO+1fr3DQKxiNpo9xudGyRZ1UOAFAZbU0bSkWx4hQOoYaYNyNk1sxzw6NJmdH1Iqu5w7NcrByQ5REDBaHiFUsFPtcdjpEoTbzJ48RXX/iHff+5jzT7/CwMzTrTU4e+YsOxureEcG6VyJlK1hawpdRERRkLSTZdJ4vR7d5j53qw3CMOS55y8xOupy8dwpLt97k15gcnJCMjaeYmZ+jInJAqmMQIoYK2Vh6gZwAsIQrx/RPGrT73i4Q484HHD65DjN7pB2z2OvXufgnfeZ2NpkxDawTcna9h5GKkeqM8BOZ7j00ito5TFeefkLFAqj7GzVKE1MsnL/Ec+/+hqTk9Pcv/kheirDzPwi3XiPdrtJJl+m3euwtbVOOZvjxMIkObPP9HiBO4+Sd0BSSiJi7IzF6GSFUOXQrAmGnSNOjVvEUYxSGkEAKI+u6xD4Ctf12d3awnOGzJTLZMoFPvfcWVYu3+JOs0Xcq6MMRbuflKc8/ZmXuF3rc+POQw6bfdwgYnJ8gkDFrG8cMFIqkRub5cHeJkQxr7/+e/zRv/gfaZiCmZlZYj3DzvY2/tE++cooeipHcaQIvcRFUCjlmJouMDpVwvF8lFmk1/QYS9VxB0lSPw4iwgj8IMYPhzihJIwi2u0OH731cyrFMufLZT7/7AUWF6d567s/pdbrkzYMhvVkr3e6aGBFGS5v1zDSGR6sb7NXa+IHAaZp8s1vfZO3rtwkX8zzs5/8GD0SLC8usd3ooAnJ/s4GM4unCTSdQb+HimKWL0wyM2lQLOvk0xamJdGNJMxlmBJdV0xqsDQ3hq4m6PUc2rUOtWoLz/HQAtCESWQLpIAgjKh3+7T6DzDWHpFJZ/hv/8v/itmzl7i+36QoI8qGyUihxLDvsjCRZyqbZ3VjH08ziITP1k6Vw2abMIiwVESpUuHbL7zMux/fIooilk4tYeNQTBmsPqwxu3CKmZN1UlbS469ij0jEaJqOZgh0XWCYyWfSSwbF0WkE0wy6Q+oHbZqHbQY9B01LXhuOllNUmx5+FCOIOGx1OWy1ub26yvWPPqV89hKNbofA6yMGHebHMsyNjxIIja7S6AwdLn3meR4+qvDo7R+zedCi3e3jOT4vfO5lDhou8ycX2ZaKkVIJd9hm0Kyiuy7/6nt/RKqzi4wDsnkbJSBdyEDsEh7WKVwaoeMNCQnZODrg1OwUW70OGPGTZlKBigV6pIOuiISgPFbg4LDPSMUmFgNiXRBqAqEpDCRvX7nK5OefYegPEWGIFBIn8MhlbAzdoOc5hCo5MKQGptBwoyFWRmGZFp4KUNLASKcojmZxr+2y8midjOchVQipMUSU5AZEFPP8yBxv+QdIqbNdq3FuaQFne0hkJlfcSosStXwUYYQ62BDEgpgQJQVGxsAuWphpncDwibUY4hihwHUS3weRIn4iLkvrBq7n4bte4kWRMVLGvPD1C2xefkSn72JkBRe/fp7OJ1tMT5SYmq5w9f03MVsb2GGMJTyoFLFNC10Ienv7XHrqWa7vPCbWFLcfPeTppVOsHOyAqUAmq7oaAhUk9lZhJq8Sg8gnWzEhFREYPhgquZgTkvHpEVrVDlNz0wzdpPoaLSaTtslbqURH7gkMFIYmMNPwn/+tV3m3ExOaBjvdBlrB4AtvfIm1d1YwvAY//P53yLV6yVZAZgyCiGw+BYHDFBYtFMqA1rBPKZ+lMlKgK0BpAiUSN4oeK0zbQFmJ0lypGC1n4OsBvvQw
DJ1QA2RiCS0Ws+zXDsjYNkPHRWga2VQaDdCFxsB1MJ6IAhESTTnM5dOc/9ZL/LMP7xCEivNfPEdhrMTrf/gG//1/+t/8eoaB//NqgmUY/MlPt5kZk8zNFBkfS1MqWuTzNpYNtiUxdFCx5PKNI/YPGgxaDcpZg5SuSBswljfRNJMgjuj3IcDi09t3OPOUycaDB2ytrnJ6aZnLNz7ixNIzzC9f5MGtT1hcnOPB2iMWx2foPO4xGHbAc+k1q/Qbh/jZFH4kODpax3cdMukUd1dWWTq1wMTcHLK5zafVJm+8cYmMZWFYMQN3SKAEmg5+HGPqSXGEoetYIybTpSy6tHj0uEs/2OYLn0szGDq0OwH3N49Qrs/TZ0qMltIMnIDtus9+3SWVL3Ll9hanl6Z46dVvsLh8ge//i+9RGR2j6wRMzoxzuLtGdeM+fqzoNqrUavs8uHedzE4B28rgNHcpFQukDZPt9cfMjkJ1e4PATdr4JifzLJ6cozJZwDRA0zU03Sb0dJQbMWgNaDfav7Q3pm1F4AZEQqfn+vR399jYPUATcPv6HQrziwz7Rzw1nWZuRkEUcnf7iJkLETPLl/jXP3qLvutgmzaz02NMnn2eGzc/5c7lq4ycXKY8vUTetllZuUUYDJicGGOv0UembXKZLKgpQgR2JkPtYJ8vXzzJ+GSOYllD0xRKxIyZqURlumCiM4czHOB1JLWDGsFg+GT1DWTsgtTwI59w4NAd7LO5d8jN+6uk8yVGJyfxB30uLRXZPzQwZAanozE2nmG+Mso//+nHNDp9YgmmEhRLI2zsHjE+McUrX3iRyx9fYfvWCgtL50jpGuOVIkYqS795xM0bVzl16mTSjZCSKD0kCGMGjsL1wbKT1x8yijFkjGXEaHriXTALFhP5SUYXRhm2QlpHHap7DTo9j5dfWGBrb8BBs0ev5+BFIX2/R/36PVb324QiJrAUM8szVJt12l0DM6tRmDjJ81+e4/s//BF3VzdxvQDP90kZJgPfod93EUYBuzBCv9/h+ReeZ/3BLew4ZnZ2modb6yzm6xQLNpmMiWlo2HqMZcdITaKbMYYrkFpiwtMCH6lFCFswfnKS8bkZXMejttdlfWOLL768xGG9T63lsr7dptcZIIBBFLJTa+NVDuk2m5QygvNLJbK2oDcM8aWDNZLl5pXrzJ+aYmevStszuPdwhyDwmZtdYHRqko2tdUyrxF/967/D6oOHrK7c5GDtMaX8BCmnR2bQQyLQLEmgK7Kawht6eGHE12aW+f7dD1EmuG2XMMpT0DUOnR7xk4KwMA6Iwhj5ZDVx0PQp2BlyhoUIY4TrIyK4tHgayxf4Fw4wC3mC3jAJU0cxoVGiUa2i2yaWrqGp5Dux0+9jaDpKRMTRE1W3qRGHMQaKbMam+WCFgqsYej66AcOegwojUCGB53P78m1mXjnJRqsOhNxYW+HSzAI3G7tAnDhZ4icyqzAifPKzjGOkgtjzmCqPIHWJcEJkFCO1pC8icHz8MNFi/4ViuiPA1HWkUpjoNHt9BIrcYpl/evpr/Mtbq9hFm3xsc/GZWdKxwcjoKK3aKpl+D5QiZer4roedMwi9AEGIvd/B9YeAIBYhKzse5yrTXGvuPHktEhGpiDiIklxaYBBFMSJWjGQspKkjXQ8tFhAq0E0q2SKRo5Ex0wy7Q6I4IgpDQqUYan1M3cSWGrFm0ex1OWVb2LZByTLYivpMTk3jeD0iU/LG3/wK1//5exR8SXfYR5MSp9NLNN9xSOQ7HB21+OqpL/OTezdAxDw82GJpdJLGsEngRyhdJS24UdKGGgTBkzMVKpk0tiXRvBhd+Ig4GWrQIWuk8FIFMpZOvzckjMJkawcwNBNLl1jCpOsNiAIPTUr+ysQkP2yu8NzvvkAswLJ0LMNirJD5v53j/08I9SssIe7u7jI7O/urzA3HHHPMMcccc8y/Z+zs7DAzM/OX/v5XGgbiOGZ/f59cLvfLqtNjjjnmmGOOOebfb5RS9Ho9
pqamki6Uv4RfaRg45phjjjnmmGP+/8tfPiYcc8wxxxxzzDG/ERwPA8ccc8wxxxzzG87xMHDMMcccc8wxv+EcDwPHHHPMMccc8xvO8TBwzDHHHHPMMb/hHA8DxxxzzDHHHPMbzvEwcMwxxxxzzDG/4fzvAXjWX5qi1OQAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_augment_2/1719298626_tango_video2audio_clip4clip_augment_2_best_steps_300_guidance_3.0_sampleRate_16000_augment/SNanDCQuatA_000189_iAEDIntNw4w_000010_6.wav\"\n", + "show_mel(file, save_name=\"2_wav.pdf\")\n", + "show_video_frames(\"../data/video_processed/video_gt_augment/SNanDCQuatA_000189_iAEDIntNw4w_000010_6.mp4\", save_name=\"2_video.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3407931/2218442382.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(720, 12800, 3)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_augment_2/1719298626_tango_video2audio_clip4clip_augment_2_best_steps_300_guidance_3.0_sampleRate_16000_augment/YszbOz38VFE_000020_h_jJj2PbFRI_000009_5.wav\"\n", + "show_mel(file, save_name=\"3_wav.pdf\")\n", + "show_video_frames(\"../data/video_processed/video_gt_augment/YszbOz38VFE_000020_h_jJj2PbFRI_000009_5.mp4\", save_name=\"3_video.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3407931/2218442382.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(360, 4800, 3)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_augment_2/1719298626_tango_video2audio_clip4clip_augment_2_best_steps_300_guidance_3.0_sampleRate_16000_augment/62Sboyh19LA_000182_I7tuwRK0L1w_000220_4.wav\"\n", + "show_mel(file, save_name=\"4_wav.pdf\")\n", + "show_video_frames(\"../data/video_processed/video_gt_augment/62Sboyh19LA_000182_I7tuwRK0L1w_000220_4.mp4\", save_name=\"4_video.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3407931/2796252854.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_augment_2/1719298626_tango_video2audio_clip4clip_augment_2_best_steps_300_guidance_3.0_sampleRate_16000_augment/VR1UJLGULOQ_000053_wiyojlC9xbI_000030_4.wav\"\n", + "show_mel(file, save_name=\"5_wav.pdf\")\n", + "show_video_frames(\"../data/video_processed/video_gt_augment/VR1UJLGULOQ_000053_wiyojlC9xbI_000030_4.mp4\", save_name=\"5_video.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_3407931/2796252854.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_augment_2/1719298626_tango_video2audio_clip4clip_augment_2_best_steps_300_guidance_3.0_sampleRate_16000_augment/OnYPsGAjYBU_000560_tPnt7aeBK44_000120_4.wav\"\n", + "show_mel(file, save_name=\"6_wav.pdf\")\n", + "show_video_frames(\"../data/video_processed/video_gt_augment/OnYPsGAjYBU_000560_tPnt7aeBK44_000120_4.mp4\", save_name=\"6_video.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2189924/99652230.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_4/1720425300_tango_video2audio_clip4clip_4_best_steps_300_guidance_3.0_sampleRate_16000_augment/0 (3).wav\"\n", + "show_mel(file, save_name=\"ood_0_wav.pdf\")\n", + "show_video_frames(\"../data/foleycrafter/0 (3).mp4\", save_name=\"ood_0_video.pdf\")" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2189924/99652230.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_4/1720425300_tango_video2audio_clip4clip_4_best_steps_300_guidance_3.0_sampleRate_16000_augment/0.wav\"\n", + "show_mel(file, save_name=\"ood_1_wav.pdf\")\n", + "show_video_frames(\"../data/foleycrafter/0.mp4\", save_name=\"ood_1_video.pdf\", frame_rate=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2189924/978430674.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(80, 1280)\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_4/1720425300_tango_video2audio_clip4clip_4_best_steps_300_guidance_3.0_sampleRate_16000_augment/1.wav\"\n", + "show_mel(file, save_name=\"ood_2_wav.pdf\")\n", + "show_video_frames(\"../data/foleycrafter/1.mp4\", save_name=\"ood_2_video.pdf\", frame_rate=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2189924/99652230.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_4/1720425300_tango_video2audio_clip4clip_4_best_steps_300_guidance_3.0_sampleRate_16000_augment/case1.wav\"\n", + "show_mel(file, save_name=\"ood_3_wav.pdf\")\n", + "show_video_frames(\"../data/foleycrafter/case1.mp4\", save_name=\"ood_3_video.pdf\", frame_rate=1)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2189924/99652230.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_4/1720425300_tango_video2audio_clip4clip_4_best_steps_300_guidance_3.0_sampleRate_16000_augment/2.wav\"\n", + "show_mel(file, save_name=\"ood_4_wav.pdf\")\n", + "show_video_frames(\"../data/foleycrafter/2.mp4\", save_name=\"ood_4_video.pdf\", frame_rate=2)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2189924/99652230.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_4/1718173660_tango_video2audio_clip4clip_4_best_steps_300_guidance_3.0_sampleRate_16000_augment/wolf_howling_at_the_moon.wav\"\n", + "show_mel(file, save_name=\"ood_5_wav.pdf\")\n", + "show_video_frames(\"../data/sora_cut_2/wolf_howling_at_the_moon.mp4\", save_name=\"ood_5_video.pdf\", frame_rate=0.2)" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/tmp/ipykernel_2189924/99652230.py:56: FutureWarning: Pass sr=16000 as keyword args. From version 0.10 passing these as positional arguments will result in an error\n", + " ori_wav = librosa.load(audio_file, sr)[0]\n" + ] + }, + { + "data": { + "image/png": "", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "file = \"../outputs/tango_video2audio_clip4clip_4/1718173660_tango_video2audio_clip4clip_4_best_steps_300_guidance_3.0_sampleRate_16000_augment/otter_surfboard.wav\"\n", + "show_mel(file, save_name=\"ood_6_wav.pdf\")\n", + "show_video_frames(\"../data/sora_cut_2/otter_surfboard.mp4\", save_name=\"ood_6_video.pdf\", frame_rate=0.4)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "audioedit", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.18" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tools/torch_tools.py b/tools/torch_tools.py new file mode 100644 index 0000000000000000000000000000000000000000..49e53d7c3c3e6df3e60ddb9503679ee924bd3188 --- /dev/null +++ b/tools/torch_tools.py @@ -0,0 +1,296 @@ +import torch +import torchaudio +import random +import itertools +import numpy as np +from tools.mix import mix +from PIL import Image +import cv2 +from moviepy.editor import VideoFileClip +from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize, InterpolationMode, RandomResizedCrop + +def normalize_wav(waveform): + waveform = waveform - torch.mean(waveform) + waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8) + return waveform * 0.5 + +def sinusoidal_positional_embedding(token_sequence_size, token_embedding_dim, n=10000.0): + + if token_embedding_dim % 2 != 0: + raise ValueError("Sinusoidal positional embedding cannot apply to odd token embedding dim (got dim={:d})".format(token_embedding_dim)) + + T = token_sequence_size + d = token_embedding_dim #d_model=head_num*d_k, not d_q, d_k, d_v + + positions = torch.arange(0, T).unsqueeze_(1) + embeddings = torch.zeros(T, d) + 
+ denominators = torch.pow(n, 2*torch.arange(0, d//2)/d) # 10000^(2i/d_model), i is the index of embedding + embeddings[:, 0::2] = torch.sin(positions/denominators) # sin(pos/10000^(2i/d_model)) + embeddings[:, 1::2] = torch.cos(positions/denominators) # cos(pos/10000^(2i/d_model)) + + return embeddings + +def pad_wav(waveform, segment_length): + waveform_length = len(waveform) + + if segment_length is None or waveform_length == segment_length: + return waveform + elif waveform_length > segment_length: + return waveform[:segment_length] + else: + pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device) + waveform = torch.cat([waveform, pad_wav]) + return waveform + + +def _pad_spec(fbank, target_length=1000): + batch, n_frames, channels = fbank.shape + p = target_length - n_frames + if p > 0: + pad = torch.zeros(batch, p, channels).to(fbank.device) + fbank = torch.cat([fbank, pad], 1) + elif p < 0: + fbank = fbank[:, :target_length, :] + + if channels % 2 != 0: + fbank = fbank[:, :, :-1] + + return fbank + + +def read_wav_file(filename, segment_length, tgt_sr=48000): + waveform, sr = torchaudio.load(filename) # Faster!!! 
+ if sr != tgt_sr: + waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=tgt_sr)[0] + else: + waveform = waveform.squeeze() + try: + waveform = normalize_wav(waveform) + except: + print ("Exception normalizing:", filename) + waveform = torch.ones(tgt_sr * 10) + waveform = pad_wav(waveform, segment_length).unsqueeze(0) + waveform = waveform / torch.max(torch.abs(waveform)) + waveform = 0.5 * waveform + return waveform + + +def get_mel_from_wav(audio, _stft): + audio1 = torch.nan_to_num(torch.clip(audio, -1, 1)) + audio2 = torch.autograd.Variable(audio1, requires_grad=False) + melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio2) + return melspec, log_magnitudes_stft, energy + +def wav_to_fbank(paths, target_length=1000, sample_rate=16000, fn_STFT=None): + assert fn_STFT is not None + if sample_rate == 16000: + hop_size = 160 + elif sample_rate == 24000: + hop_size = 240 + elif sample_rate == 32000: + hop_size = 320 + elif sample_rate == 48000: + hop_size = 480 + else: + raise ValueError(f"sample_rate wrong.") + + #print("target_length", target_length, hop_size) + #print("target_length", target_length, sample_rate, fn_STFT) + #for name, param in fn_STFT.named_parameters(): + # print(name, param.data) + waveform = torch.cat([read_wav_file(path, target_length * hop_size, tgt_sr=sample_rate) for path in paths], 0) # hop size is 160 + #print("waveform", waveform.size()) + + #np.set_printoptions(threshold=np.inf) + #print("waveform", waveform) + #f_out = open(paths[0].split("/")[-1]+".scp",'w') + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + #print("fbank", fbank) + fbank = fbank.transpose(1, 2) + log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + + #f_out.write(paths[0]+ "\n" + str(waveform.cpu().numpy())+"\n") + #f_out.write("audio1"+ "\n" + str(audio1.cpu().numpy())+"\n") + 
#f_out.write("audio2"+ "\n" + str(audio2.cpu().numpy())+"\n") + #f_out.write("fbank" + "\n" + str(fbank.cpu().numpy())+"\n") + #print(fbank2) + return fbank, log_magnitudes_stft, waveform + +def get_wav_from_video(video_path, segment_length, tgt_sr=48000): + video = VideoFileClip(video_path) + audio = video.audio + sr = audio.fps + audio_data = audio.to_soundarray() # 441882 * 2 双通道 + waveform = torch.mean(torch.tensor(audio_data, dtype=torch.float), dim=1).unsqueeze(0) # 变成单通道 + if sr != tgt_sr: + waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=tgt_sr)[0] + else: + waveform = waveform.squeeze() + try: + waveform = normalize_wav(waveform) + except: + print ("Exception normalizing:", video_path) + waveform = torch.ones(tgt_sr * 10) + waveform = pad_wav(waveform, segment_length).unsqueeze(0) + waveform = waveform / torch.max(torch.abs(waveform)) + waveform = 0.5 * waveform + return waveform + +def get_wavs_from_videos(video_paths, segment_length, tgt_sr=48000): + wavs = [] + for video_path in video_paths: + waveform = get_wav_from_video(video_path, segment_length, tgt_sr) + wavs.append(waveform) + wavs = torch.cat(wavs, 0) + return wavs + +def wav_in_video_to_fbank(input, target_length=1000, sample_rate=16000, fn_STFT=None, waveform=False): + assert fn_STFT is not None + if sample_rate == 16000: + hop_size = 160 + elif sample_rate == 24000: + hop_size = 240 + elif sample_rate == 32000: + hop_size = 320 + elif sample_rate == 48000: + hop_size = 480 + else: + raise ValueError(f"sample_rate wrong.") + + if not waveform: + paths = input + waveform = get_wavs_from_videos(paths, target_length * hop_size, tgt_sr=sample_rate) # hop size is 160 + else: + waveform = input + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + fbank = fbank.transpose(1, 2) + log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + 
) + return fbank, log_magnitudes_stft, waveform + + +def uncapitalize(s): + if s: + return s[:1].lower() + s[1:] + else: + return "" + + +def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1000, sample_rate=16000): + + if sample_rate == 16000: + hop_size = 160 + elif sample_rate == 24000: + hop_size = 240 + elif sample_rate == 32000: + hop_size = 320 + elif sample_rate == 48000: + hop_size = 480 + else: + raise ValueError(f"sample_rate wrong.") + + sound1 = read_wav_file(path1, target_length * hop_size)[0].numpy() + #print("sound1", target_length, sound1.size) + sound2 = read_wav_file(path2, target_length * hop_size)[0].numpy() + mixed_sound = mix(sound1, sound2, 0.5, sample_rate).reshape(1, -1) + #print("mixed_sound", mixed_sound.size) + mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2)) + return mixed_sound, mixed_caption + + +def augment(paths, texts, num_items=4, target_length=1000, sample_rate=16000): + mixed_sounds, mixed_captions = [], [] + combinations = list(itertools.combinations(list(range(len(texts))), 2)) + random.shuffle(combinations) + if len(combinations) < num_items: + selected_combinations = combinations + else: + selected_combinations = combinations[:num_items] + + for (i, j) in selected_combinations: + new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length, sample_rate) + mixed_sounds.append(new_sound) + mixed_captions.append(new_caption) + + waveform = torch.tensor(np.concatenate(mixed_sounds, 0)) + waveform = waveform / torch.max(torch.abs(waveform)) + waveform = 0.5 * waveform + + return waveform, mixed_captions + + +def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1000, sample_rate=16000, fn_STFT=None): + assert fn_STFT is not None + + waveform, captions = augment(paths, texts, target_length = target_length, sample_rate=sample_rate) + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + fbank = fbank.transpose(1, 2) + 
log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + + return fbank, log_magnitudes_stft, waveform, captions + + +def load_image(impaths, crop_size=384): + imgs = [] + RGB_mean = [0.485, 0.456, 0.406] + RGB_std = [0.229, 0.224, 0.225] + image_resize_and_crop = Compose([RandomResizedCrop(crop_size), ToTensor()]) + image_normalize = Normalize(mean=RGB_mean, std=RGB_std) + for impath in impaths: + img = Image.open(impath).convert('RGB') + img = image_resize_and_crop(img) + img = image_normalize(img) + imgs.append(img) + imgs = torch.stack(imgs) + + return imgs + +def load_video(video_path, frame_rate=1.0, size=224): + def preprocess(size, n_px): + return Compose([ + Resize(size, interpolation=InterpolationMode.BICUBIC), + CenterCrop(size), + lambda image: image.convert("RGB"), + ToTensor(), + # Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ])(n_px) + videos = [] + # for video_path in video_paths: + # cap = cv2.VideoCapture(video_path) + cap = cv2.VideoCapture(video_path, cv2.CAP_FFMPEG) + frameCount = int(cap.get(cv2.CAP_PROP_FRAME_COUNT)) + fps = int(cap.get(cv2.CAP_PROP_FPS)) + if fps < 1: + images = np.zeros([3, size, size], dtype=np.float32) + print("ERROR: problem reading video file: ", video_path) + else: + total_duration = (frameCount + fps - 1) // fps + start_sec, end_sec = 0, total_duration + interval = fps / frame_rate + frames_idx = np.floor(np.arange(start_sec*fps, end_sec*fps, interval)) + ret = True + images = np.zeros([len(frames_idx), 3, size, size], dtype=np.float32) + + for i, idx in enumerate(frames_idx): + cap.set(cv2.CAP_PROP_POS_FRAMES , idx) + ret, frame = cap.read() + if not ret: break + frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + last_frame = i + images[i,:,:,:] = preprocess(size, Image.fromarray(frame).convert("RGB")) + + images = images[:last_frame+1] + cap.release() + 
return torch.tensor(images) \ No newline at end of file