from typing import Literal, Optional

import open_clip
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from open_clip import create_model_from_pretrained
from torchvision.transforms import Normalize

from mmaudio.ext.autoencoder import AutoEncoderModule
from mmaudio.ext.mel_converter import MelConverter
from mmaudio.ext.synchformer import Synchformer
from mmaudio.model.utils.distributions import DiagonalGaussianDistribution


def patch_clip(clip_model):
    # a hack to make it output last hidden states
    # https://github.com/mlfoundations/open_clip/blob/fc5a37b72d705f760ebbc7915b84729816ed471f/src/open_clip/model.py#L269
    def new_encode_text(self, text, normalize: bool = False):
        cast_dtype = self.transformer.get_cast_dtype()

        x = self.token_embedding(text).to(cast_dtype)  # [batch_size, n_ctx, d_model]
        x = x + self.positional_embedding.to(cast_dtype)
        x = self.transformer(x, attn_mask=self.attn_mask)
        x = self.ln_final(x)  # [batch_size, n_ctx, transformer.width]
        return F.normalize(x, dim=-1) if normalize else x

    clip_model.encode_text = new_encode_text.__get__(clip_model)
    return clip_model


class FeaturesUtils(nn.Module):

    def __init__(
        self,
        *,
        tod_vae_ckpt: Optional[str] = None,
        bigvgan_vocoder_ckpt: Optional[str] = None,
        synchformer_ckpt: Optional[str] = None,
        enable_conditions: bool = True,
        mode: Literal['16k', '44k'],
    ):
        super().__init__()

        if enable_conditions:
            self.clip_model = create_model_from_pretrained('hf-hub:apple/DFN5B-CLIP-ViT-H-14-384',
                                                           return_transform=False)
            self.clip_preprocess = Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                             std=[0.26862954, 0.26130258, 0.27577711])
            self.clip_model = patch_clip(self.clip_model)

            self.synchformer = Synchformer()
            self.synchformer.load_state_dict(
                torch.load(synchformer_ckpt, weights_only=True, map_location='cpu'))

            self.tokenizer = open_clip.get_tokenizer('ViT-H-14-378-quickgelu')  # same as 'ViT-H-14'
        else:
            self.clip_model = None
            self.synchformer = None
            self.tokenizer = None

        if tod_vae_ckpt is not None:
            self.tod = AutoEncoderModule(vae_ckpt_path=tod_vae_ckpt,
                                         vocoder_ckpt_path=bigvgan_vocoder_ckpt,
                                         mode=mode)
        else:
            self.tod = None

        self.mel_converter = MelConverter()

    def compile(self):
        if self.clip_model is not None:
            self.encode_video_with_clip = torch.compile(self.encode_video_with_clip)
            self.clip_model.encode_image = torch.compile(self.clip_model.encode_image)
            self.clip_model.encode_text = torch.compile(self.clip_model.encode_text)
        if self.synchformer is not None:
            self.synchformer = torch.compile(self.synchformer)
        if self.tod is not None:
            self.tod.encode = torch.compile(self.tod.encode)
        self.decode = torch.compile(self.decode)
        self.vocode = torch.compile(self.vocode)

    def train(self, mode: bool) -> None:
        # this module is frozen feature extractors; always keep it in eval mode
        return super().train(False)

    @torch.inference_mode()
    def encode_video_with_clip(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'
        # x: (B, T, C, H, W) H/W: 384
        b, t, c, h, w = x.shape
        assert c == 3 and h == 384 and w == 384
        x = self.clip_preprocess(x)
        x = rearrange(x, 'b t c h w -> (b t) c h w')
        outputs = []
        if batch_size < 0:
            batch_size = b * t
        for i in range(0, b * t, batch_size):
            outputs.append(self.clip_model.encode_image(x[i:i + batch_size], normalize=True))
        x = torch.cat(outputs, dim=0)
        # x = self.clip_model.encode_image(x, normalize=True)
        x = rearrange(x, '(b t) d -> b t d', b=b)
        return x

    @torch.inference_mode()
    def encode_video_with_sync(self, x: torch.Tensor, batch_size: int = -1) -> torch.Tensor:
        assert self.synchformer is not None, 'Synchformer is not loaded'
        # x: (B, T, C, H, W) H/W: 224
        b, t, c, h, w = x.shape
        assert c == 3 and h == 224 and w == 224

        # partition the video into overlapping segments of 16 frames with a stride of 8
        segment_size = 16
        step_size = 8
        num_segments = (t - segment_size) // step_size + 1
        segments = []
        for i in range(num_segments):
            segments.append(x[:, i * step_size:i * step_size + segment_size])
        x = torch.stack(segments, dim=1)  # (B, S, T, C, H, W)

        outputs = []
        if batch_size < 0:
            batch_size = b
        for i in range(0, b, batch_size):
            outputs.append(self.synchformer(x[i:i + batch_size]))
        x = torch.cat(outputs, dim=0).flatten(start_dim=1, end_dim=2)
        return x

    @torch.inference_mode()
    def encode_text(self, text: list[str]) -> torch.Tensor:
        assert self.clip_model is not None, 'CLIP is not loaded'
        assert self.tokenizer is not None, 'Tokenizer is not loaded'
        # tokens: (B, L)
        tokens = self.tokenizer(text).to(self.device)
        return self.clip_model.encode_text(tokens, normalize=True)

    @torch.inference_mode()
    def encode_audio(self, x) -> DiagonalGaussianDistribution:
        assert self.tod is not None, 'VAE is not loaded'
        # x: (B, L) waveform -> mel spectrogram -> latent distribution
        mel = self.mel_converter(x)
        dist = self.tod.encode(mel)
        return dist

    @torch.inference_mode()
    def vocode(self, mel: torch.Tensor) -> torch.Tensor:
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.vocode(mel)

    @torch.inference_mode()
    def decode(self, z: torch.Tensor) -> torch.Tensor:
        assert self.tod is not None, 'VAE is not loaded'
        return self.tod.decode(z.transpose(1, 2))

    @property
    def device(self):
        return next(self.parameters()).device

    @property
    def dtype(self):
        return next(self.parameters()).dtype
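# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the module's API).
# The checkpoint paths below are hypothetical placeholders: point them at the
# VAE, BigVGAN vocoder, and Synchformer weights you have downloaded. Frame
# counts and audio length are arbitrary examples; the downstream model may
# expect specific sequence lengths.
if __name__ == '__main__':
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    features = FeaturesUtils(
        tod_vae_ckpt='./ext_weights/vae.pth',  # hypothetical path
        bigvgan_vocoder_ckpt='./ext_weights/vocoder.pt',  # hypothetical path
        synchformer_ckpt='./ext_weights/synchformer.pth',  # hypothetical path
        enable_conditions=True,
        mode='16k',
    ).to(device).eval()

    # CLIP expects 384x384 frames; Synchformer expects 224x224 frames.
    clip_frames = torch.randn(1, 8, 3, 384, 384, device=device)   # (B, T, C, H, W)
    sync_frames = torch.randn(1, 32, 3, 224, 224, device=device)  # (B, T, C, H, W)
    waveform = torch.randn(1, 16000, device=device)               # (B, L), ~1 s at 16 kHz

    clip_feat = features.encode_video_with_clip(clip_frames)  # (B, T, D) per-frame features
    sync_feat = features.encode_video_with_sync(sync_frames)  # (B, S * tokens_per_segment, D)
    text_feat = features.encode_text(['a dog barking'])       # (B, n_ctx, D) per-token features
    audio_dist = features.encode_audio(waveform)               # DiagonalGaussianDistribution
    latent = audio_dist.sample()                               # latent for the audio VAE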