import torch
import torch.nn as nn
import numpy as np
import os
from typing import List

from diffusers import StableDiffusionPipeline
from diffusers.pipelines.controlnet import MultiControlNetModel
from PIL import Image
from safetensors import safe_open
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

from foleycrafter.models.adapters.resampler import Resampler
from foleycrafter.models.adapters.utils import is_torch2_available


class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, unet, image_proj_model, adapter_modules, ckpt_path=None):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

        if ckpt_path is not None:
            self.load_from_checkpoint(ckpt_path)

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

    def load_from_checkpoint(self, ckpt_path: str):
        # Calculate original checksums
        orig_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        orig_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        state_dict = torch.load(ckpt_path, map_location="cpu")

        # Load state dict for image_proj_model and adapter_modules
        self.image_proj_model.load_state_dict(state_dict["image_proj"], strict=True)
        self.adapter_modules.load_state_dict(state_dict["ip_adapter"], strict=True)

        # Calculate new checksums
        new_ip_proj_sum = torch.sum(torch.stack([torch.sum(p) for p in self.image_proj_model.parameters()]))
        new_adapter_sum = torch.sum(torch.stack([torch.sum(p) for p in self.adapter_modules.parameters()]))

        # Verify if the weights have changed
        assert orig_ip_proj_sum != new_ip_proj_sum, "Weights of image_proj_model did not change!"
        assert orig_adapter_sum != new_adapter_sum, "Weights of adapter_modules did not change!"
print(f"Successfully loaded weights from checkpoint {ckpt_path}") class VideoProjModel(torch.nn.Module): """Projection Model""" def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=1, video_frame=50): super().__init__() self.cross_attention_dim = cross_attention_dim self.clip_extra_context_tokens = clip_extra_context_tokens self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim) self.video_frame = video_frame def forward(self, image_embeds): embeds = image_embeds clip_extra_context_tokens = self.proj(embeds) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens class ImageProjModel(torch.nn.Module): """Projection Model""" def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): super().__init__() self.cross_attention_dim = cross_attention_dim self.clip_extra_context_tokens = clip_extra_context_tokens self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim) def forward(self, image_embeds): embeds = image_embeds clip_extra_context_tokens = self.proj(embeds).reshape( -1, self.clip_extra_context_tokens, self.cross_attention_dim ) clip_extra_context_tokens = self.norm(clip_extra_context_tokens) return clip_extra_context_tokens class MLPProjModel(torch.nn.Module): """SD model with image prompt""" def zero_initialize(module): for param in module.parameters(): param.data.zero_() def zero_initialize_last_layer(module): last_layer = None for module_name, layer in module.named_modules(): if isinstance(layer, torch.nn.Linear): last_layer = layer if last_layer is not None: last_layer.weight.data.zero_() last_layer.bias.data.zero_() def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): super().__init__() self.proj = torch.nn.Sequential( torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim), torch.nn.GELU(), torch.nn.Linear(clip_embeddings_dim, cross_attention_dim), torch.nn.LayerNorm(cross_attention_dim) ) # zero initialize the last layer # self.zero_initialize_last_layer() def forward(self, image_embeds): clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens class V2AMapperMLP(torch.nn.Module): def __init__(self, cross_attention_dim=512, clip_embeddings_dim=512, mult=4): super().__init__() self.proj = torch.nn.Sequential( torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim * mult), torch.nn.GELU(), torch.nn.Linear(clip_embeddings_dim * mult, cross_attention_dim), torch.nn.LayerNorm(cross_attention_dim) ) def forward(self, image_embeds): clip_extra_context_tokens = self.proj(image_embeds) return clip_extra_context_tokens class TimeProjModel(torch.nn.Module): def __init__(self, positive_len, out_dim, feature_type="text-only", frame_nums:int=64): super().__init__() self.positive_len = positive_len self.out_dim = out_dim self.position_dim = frame_nums if isinstance(out_dim, tuple): out_dim = out_dim[0] if feature_type == "text-only": self.linears = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) elif feature_type == "text-image": self.linears_text = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), 
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.linears_image = nn.Sequential(
                nn.Linear(self.positive_len + self.position_dim, 512),
                nn.SiLU(),
                nn.Linear(512, 512),
                nn.SiLU(),
                nn.Linear(512, out_dim),
            )
            self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
            self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))

        # self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))

    def forward(
        self,
        boxes,
        masks,
        positive_embeddings=None,
    ):
        masks = masks.unsqueeze(-1)

        # # embedding position (it may include padding as a placeholder)
        # xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C

        # # learnable null embedding
        # xyxy_null = self.null_position_feature.view(1, 1, -1)

        # # replace padding with learnable null embedding
        # xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null

        time_embeds = boxes

        # positionet with text-only information
        if positive_embeddings is not None:
            # learnable null embedding
            positive_null = self.null_positive_feature.view(1, 1, -1)

            # replace padding with learnable null embedding
            positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null

            objs = self.linears(torch.cat([positive_embeddings, time_embeds], dim=-1))
        # positionet with text and image information
        else:
            raise NotImplementedError

        return objs
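

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API).
# The batch sizes, embedding widths, and frame counts below are assumptions
# chosen to match the default arguments above, not values mandated upstream.
# It shows how ImageProjModel turns a pooled CLIP image embedding into extra
# cross-attention tokens, how V2AMapperMLP remaps a 512-d embedding, and how
# TimeProjModel fuses per-entry features with temporal embeddings.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Fake pooled CLIP image embedding: (batch, clip_embeddings_dim)
    image_embeds = torch.randn(2, 1024)

    # Project to 4 extra context tokens of width cross_attention_dim=1024
    proj = ImageProjModel(cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4)
    ip_tokens = proj(image_embeds)
    print(ip_tokens.shape)  # torch.Size([2, 4, 1024])

    # V2AMapperMLP keeps a single embedding vector per sample but widens the
    # hidden layer by `mult` before projecting back to cross_attention_dim.
    mapper = V2AMapperMLP(cross_attention_dim=512, clip_embeddings_dim=512, mult=4)
    mapped = mapper(torch.randn(2, 512))
    print(mapped.shape)  # torch.Size([2, 512])

    # TimeProjModel concatenates per-entry features with a temporal embedding;
    # masks mark which entries are real (1) versus padding (0).
    time_proj = TimeProjModel(positive_len=768, out_dim=1024, frame_nums=64)
    text_feats = torch.randn(2, 3, 768)  # (batch, num_entries, positive_len)
    time_embeds = torch.randn(2, 3, 64)  # (batch, num_entries, frame_nums)
    masks = torch.ones(2, 3)             # all entries valid
    objs = time_proj(time_embeds, masks, positive_embeddings=text_feats)
    print(objs.shape)  # torch.Size([2, 3, 1024])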