# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn

from ..utils import deprecate
from .activations import FP32SiLU, get_activation
from .attention_processor import Attention


def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
):
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

    Args
        timesteps (torch.Tensor):
            a 1-D Tensor of N indices, one per batch element. These may be fractional.
        embedding_dim (int):
            the dimension of the output.
        flip_sin_to_cos (bool):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
        downscale_freq_shift (float):
            Controls the delta between frequencies between dimensions. The exponent is divided by
            `embedding_dim // 2 - downscale_freq_shift`.
        scale (float):
            Scaling factor applied to the sinusoid arguments (before `sin`/`cos` are taken).
        max_period (int):
            Controls the maximum period, i.e. the lowest frequency, of the embeddings.

    Returns
        torch.Tensor: an [N x dim] Tensor of positional embeddings. When `embedding_dim` is odd, the last column is
        zero padding.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def get_3d_sincos_pos_embed(
    embed_dim: int,
    spatial_size: Union[int, Tuple[int, int]],
    temporal_size: int,
    spatial_interpolation_scale: float = 1.0,
    temporal_interpolation_scale: float = 1.0,
) -> np.ndarray:
    r"""
    Build a joint temporal + spatial sin/cos positional embedding.

    Args:
        embed_dim (`int`):
            Total embedding dimension; must be divisible by 4. Three quarters are used for the spatial part and one
            quarter for the temporal part.
        spatial_size (`int` or `Tuple[int, int]`):
            Spatial grid size. An `int` is expanded to a square grid. `spatial_size[0]` is used as the width and
            `spatial_size[1]` as the height.
        temporal_size (`int`):
            Number of temporal positions.
        spatial_interpolation_scale (`float`, defaults to 1.0):
            Divisor applied to the spatial grid coordinates.
        temporal_interpolation_scale (`float`, defaults to 1.0):
            Divisor applied to the temporal grid coordinates.

    Returns:
        `np.ndarray` of shape `[temporal_size, spatial_size[0] * spatial_size[1], embed_dim]`, with the temporal
        channels first along the last axis.

    Raises:
        ValueError: If `embed_dim` is not divisible by 4.
    """
    if embed_dim % 4 != 0:
        raise ValueError("`embed_dim` must be divisible by 4")
    if isinstance(spatial_size, int):
        spatial_size = (spatial_size, spatial_size)

    embed_dim_spatial = 3 * embed_dim // 4
    embed_dim_temporal = embed_dim // 4

    # 1. Spatial
    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)

    # 2. Temporal
    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)

    # 3. Concat
    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]

    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]

    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
    return pos_embed


def get_2d_sincos_pos_embed(
    embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
):
    """
    Build a 2D sin/cos positional embedding for a grid.

    Args:
        embed_dim: Embedding dimension (must be even, see `get_2d_sincos_pos_embed_from_grid`).
        grid_size: `int` (expanded to a square grid) or a 2-tuple giving the grid size.
        cls_token: If True and `extra_tokens > 0`, `extra_tokens` all-zero rows are prepended.
        extra_tokens: Number of zero rows to prepend when `cls_token` is True.
        interpolation_scale: Divisor applied to the grid coordinates.
        base_size: Coordinates are rescaled by `base_size / grid_size[i]` along each axis.

    Returns:
        pos_embed: [grid_size*grid_size, embed_dim] or [extra_tokens+grid_size*grid_size, embed_dim] (w/o or w/
        cls_token)
    """
    if isinstance(grid_size, int):
        grid_size = (grid_size, grid_size)

    grid_h = np.arange(grid_size[0], dtype=np.float32) / (grid_size[0] / base_size) / interpolation_scale
    grid_w = np.arange(grid_size[1], dtype=np.float32) / (grid_size[1] / base_size) / interpolation_scale
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)

    grid = grid.reshape([2, 1, grid_size[1], grid_size[0]])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)  # [grid_size**2, embed_dim]
    if cls_token and extra_tokens > 0:
        pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0)
    return pos_embed


def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """
    Encode a stacked 2D coordinate grid with sin/cos embeddings.

    Half of `embed_dim` encodes `grid[0]` and the other half encodes `grid[1]`; the two halves are concatenated along
    the channel axis.

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    # use half of dimensions to encode grid_h
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])  # (H*W, D/2)

    emb = np.concatenate([emb_h, emb_w], axis=1)  # (H*W, D)
    return emb


def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """
    Encode a set of scalar positions with sin/cos embeddings.

    embed_dim: output dimension for each position
    pos: an array of positions to be encoded (flattened to size (M,))
    out: (M, D), laid out as [sin | cos] along the last axis

    Raises:
        ValueError: If `embed_dim` is odd.
    """
    if embed_dim % 2 != 0:
        raise ValueError("embed_dim must be divisible by 2")

    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.0
    omega = 1.0 / 10000**omega  # (D/2,)

    pos = pos.reshape(-1)  # (M,)
    out = np.einsum("m,d->md", pos, omega)  # (M, D/2), outer product

    emb_sin = np.sin(out)  # (M, D/2)
    emb_cos = np.cos(out)  # (M, D/2)

    emb = np.concatenate([emb_sin, emb_cos], axis=1)  # (M, D)
    return emb


class PatchEmbed(nn.Module):
    """2D Image to Patch Embedding with support for SD3 cropping.

    The input is projected with a strided convolution (kernel and stride equal to `patch_size`), optionally flattened
    to a token sequence and layer-normalized, and a sin/cos positional embedding is added unless `pos_embed_type` is
    `None`.
    """

    def __init__(
        self,
        height=224,
        width=224,
        patch_size=16,
        in_channels=3,
        embed_dim=768,
        layer_norm=False,
        flatten=True,
        bias=True,
        interpolation_scale=1,
        pos_embed_type="sincos",
        pos_embed_max_size=None,  # For SD3 cropping
    ):
        super().__init__()

        num_patches = (height // patch_size) * (width // patch_size)
        self.flatten = flatten
        self.layer_norm = layer_norm
        self.pos_embed_max_size = pos_embed_max_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        if layer_norm:
            self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False, eps=1e-6)
        else:
            self.norm = None

        self.patch_size = patch_size
        # Stored in units of patches, not pixels.
        self.height, self.width = height // patch_size, width // patch_size
        self.base_size = height // patch_size
        self.interpolation_scale = interpolation_scale

        # Calculate positional embeddings based on max size or default
        if pos_embed_max_size:
            grid_size = pos_embed_max_size
        else:
            grid_size = int(num_patches**0.5)

        if pos_embed_type is None:
            self.pos_embed = None
        elif pos_embed_type == "sincos":
            pos_embed = get_2d_sincos_pos_embed(
                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
            )  # [grid_size**2, embed_dim]
            # The buffer is only saved in the state dict when a max size is configured.
            persistent = bool(pos_embed_max_size)
            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
        else:
            raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")

    def cropped_pos_embed(self, height, width):
        """Crops positional embeddings for SD3 compatibility.

        `height` and `width` are divided by `patch_size`, then a centered window of that size is cut out of the
        `pos_embed_max_size x pos_embed_max_size` embedding grid.

        Raises:
            ValueError: If `pos_embed_max_size` is unset, or the requested window exceeds it in either dimension.
        """
        if self.pos_embed_max_size is None:
            raise ValueError("`pos_embed_max_size` must be set for cropping.")

        height = height // self.patch_size
        width = width // self.patch_size
        if height > self.pos_embed_max_size:
            raise ValueError(
                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )
        if width > self.pos_embed_max_size:
            raise ValueError(
                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
            )

        top = (self.pos_embed_max_size - height) // 2
        left = (self.pos_embed_max_size - width) // 2
        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
        return spatial_pos_embed

    def forward(self, latent):
        """Patchify `latent` and add the positional embedding (cropped, recomputed, or cached as appropriate)."""
        if self.pos_embed_max_size is not None:
            # Pixel-space size; `cropped_pos_embed` converts it to patches itself.
            height, width = latent.shape[-2:]
        else:
            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size

        latent = self.proj(latent)
        if self.flatten:
            latent = latent.flatten(2).transpose(1, 2)  # BCHW -> BNC
        if self.layer_norm:
            latent = self.norm(latent)
        if self.pos_embed is None:
            return latent
        # Interpolate or crop positional embeddings as needed
        if self.pos_embed_max_size:
            pos_embed = self.cropped_pos_embed(height, width)
        else:
            if self.height != height or self.width != width:
                # Input size differs from the configured one: rebuild the embedding for the actual grid.
                pos_embed = get_2d_sincos_pos_embed(
                    embed_dim=self.pos_embed.shape[-1],
                    grid_size=(height, width),
                    base_size=self.base_size,
                    interpolation_scale=self.interpolation_scale,
                )
                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
            else:
                pos_embed = self.pos_embed

        # The sum may be promoted by the float32 embedding; cast back to the input dtype.
        return (latent + pos_embed).to(latent.dtype)


class LuminaPatchEmbed(nn.Module):
    """2D Image to Patch Embedding with support for Lumina-T2X"""

    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
        super().__init__()
        self.patch_size = patch_size
        self.proj = nn.Linear(
            in_features=patch_size * patch_size * in_channels,
            out_features=embed_dim,
            bias=bias,
        )

    def forward(self, x, freqs_cis):
        """
        Patchifies and embeds the input tensor(s).

        Args:
            x (torch.Tensor): The input tensor to be patchified and embedded; unpacked as
                (batch_size, channel, height, width).
            freqs_cis (torch.Tensor): Frequency tensor; it is moved to the device of `x` and sliced to the token grid
                `[:height_tokens, :width_tokens]`.

        Returns:
            Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
            and embedded tensor(s), the mask indicating the valid patches (all ones), the original image size(s), and
            the frequency tensor(s).
        """
        freqs_cis = freqs_cis.to(x[0].device)
        patch_height = patch_width = self.patch_size
        batch_size, channel, height, width = x.size()
        height_tokens, width_tokens = height // patch_height, width // patch_width

        x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
            0, 2, 4, 1, 3, 5
        )
        x = x.flatten(3)
        x = self.proj(x)
        x = x.flatten(1, 2)

        mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)

        return (
            x,
            mask,
            [(height, width)] * batch_size,
            freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
        )


class CogVideoXPatchEmbed(nn.Module):
    """Projects text embeddings and per-frame image patches to a shared width and concatenates them."""

    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 16,
        embed_dim: int = 1920,
        text_embed_dim: int = 4096,
        bias: bool = True,
    ) -> None:
        super().__init__()
        self.patch_size = patch_size

        self.proj = nn.Conv2d(
            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
        )
        self.text_proj = nn.Linear(text_embed_dim, embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        r"""
        Args:
            text_embeds (`torch.Tensor`):
                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
            image_embeds (`torch.Tensor`):
                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).

        Returns:
            `torch.Tensor`: Text tokens followed by image patch tokens, concatenated along the sequence axis.
        """
        text_embeds = self.text_proj(text_embeds)

        batch, num_frames, channels, height, width = image_embeds.shape
        # Fold frames into the batch axis so the 2D convolution sees one frame per sample.
        image_embeds = image_embeds.reshape(-1, channels, height, width)
        image_embeds = self.proj(image_embeds)
        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]

        embeds = torch.cat(
            [text_embeds, image_embeds], dim=1
        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
        return embeds


def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
    """
    RoPE for image tokens with 2d structure.

    Args:
    embed_dim: (`int`):
        The embedding dimension size
    crops_coords (`Tuple[int]`)
        The top-left and bottom-right coordinates of the crop.
    grid_size (`Tuple[int]`):
        The grid size of the positional embedding.
    use_real (`bool`):
        If True, return real part and imaginary part separately. Otherwise, return complex numbers.

    Returns:
        If `use_real`, a `(cos, sin)` tuple of `torch.Tensor`, each of shape `(grid_size[0] * grid_size[1],
        embed_dim)`. Otherwise a single complex `torch.Tensor` of shape `(grid_size[0] * grid_size[1], embed_dim/2)`.
    """
    start, stop = crops_coords
    grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
    grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
    grid = np.stack(grid, axis=0)  # [2, W, H]

    grid = grid.reshape([2, 1, *grid.shape[1:]])
    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
    return pos_embed


def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
    """
    Build 2D rotary embeddings from a stacked coordinate grid.

    Half of `embed_dim` encodes `grid[0]` and the other half `grid[1]`. `embed_dim` must be divisible by 4.
    """
    assert embed_dim % 4 == 0

    # use half of dimensions to encode grid_h
    emb_h = get_1d_rotary_pos_embed(
        embed_dim // 2, grid[0].reshape(-1), use_real=use_real
    )  # (H*W, D/2) if use_real else (H*W, D/4)
    emb_w = get_1d_rotary_pos_embed(
        embed_dim // 2, grid[1].reshape(-1), use_real=use_real
    )  # (H*W, D/2) if use_real else (H*W, D/4)

    if use_real:
        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
        return cos, sin
    else:
        emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
        return emb


def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
    """
    Build a complex 2D rotary embedding grid of shape (len_h, len_w, embed_dim / 2).

    Per-row and per-column 1D rotary embeddings are broadcast over the grid and interleaved along the last axis.
    `embed_dim` must be divisible by 4.
    """
    assert embed_dim % 4 == 0

    emb_h = get_1d_rotary_pos_embed(
        embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
    )  # (H, D/4)
    emb_w = get_1d_rotary_pos_embed(
        embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
    )  # (W, D/4)
    emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1)  # (H, W, D/4, 1)
    emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1)  # (H, W, D/4, 1)

    emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2)  # (H, W, D/2)
    return emb


def get_1d_rotary_pos_embed(
    dim: int,
    pos: Union[np.ndarray, int],
    theta: float = 10000.0,
    use_real=False,
    linear_factor=1.0,
    ntk_factor=1.0,
    repeat_interleave_real=True,
):
    """
    Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

    This function calculates a frequency tensor with complex exponentials using the given dimension 'dim' and the
    position indices 'pos'. The 'theta' parameter scales the frequencies. Unless `use_real` is set, the returned
    tensor contains complex values in complex64 data type.

    Args:
        dim (`int`): Dimension of the frequency tensor. Must be even.
        pos (`np.ndarray` or `int`): Position indices for the frequency tensor. [S] or scalar. An `int` is expanded
            to `np.arange(pos)`.
        theta (`float`, *optional*, defaults to 10000.0):
            Scaling factor for frequency computation. Defaults to 10000.0.
        use_real (`bool`, *optional*):
            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
        linear_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the context extrapolation. Defaults to 1.0.
        ntk_factor (`float`, *optional*, defaults to 1.0):
            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
            Otherwise, they are concatenated with themselves.
    Returns:
        `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]. If `use_real`, a
        `(cos, sin)` tuple of real tensors, each [S, D], is returned instead.
    """
    assert dim % 2 == 0

    if isinstance(pos, int):
        pos = np.arange(pos)
    theta = theta * ntk_factor
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
    t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
    freqs = torch.outer(t, freqs).float()  # type: ignore   # [S, D/2]
    if use_real and repeat_interleave_real:
        freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
        freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
        return freqs_cos, freqs_sin
    elif use_real:
        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
        return freqs_cos, freqs_sin
    else:
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64     # [S, D/2]
        return freqs_cis


def apply_rotary_emb(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary embeddings to an input tensor using the given frequency tensor.

    This function applies rotary embeddings to the given query or key tensor 'x' using the provided frequency tensor
    'freqs_cis'. The result has the same dtype as 'x'.

    Args:
        x (`torch.Tensor`):
            Query or key tensor to apply rotary embeddings. [B, H, S, D]
        freqs_cis (`Tuple[torch.Tensor]` or `torch.Tensor`):
            Precomputed frequencies. A `(cos, sin)` tuple ([S, D], [S, D],) when `use_real` is True, otherwise a
            single complex tensor.
        use_real (`bool`, defaults to `True`):
            Whether `freqs_cis` is a real `(cos, sin)` pair (True) or a complex tensor (False).
        use_real_unbind_dim (`int`, defaults to -1):
            Only used when `use_real` is True. `-1` treats adjacent channel pairs as (real, imag); `-2` treats the
            first and second halves of the channel axis as (real, imag).

    Returns:
        `torch.Tensor`: The input tensor with rotary embeddings applied.

    Raises:
        ValueError: If `use_real` is True and `use_real_unbind_dim` is neither -1 nor -2.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Use for example in Lumina
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Use for example in Stable Audio
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        # Rotate in float32, then cast back to the input dtype.
        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        freqs_cis = freqs_cis.unsqueeze(2)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)


class TimestepEmbedding(nn.Module):
    """Two-layer MLP (linear -> activation -> linear, optional post-activation) applied to a timestep projection.

    When `cond_proj_dim` is given, a bias-free linear projection of `condition` is added to the input first.
    """

    def __init__(
        self,
        in_channels: int,
        time_embed_dim: int,
        act_fn: str = "silu",
        out_dim: Optional[int] = None,
        post_act_fn: Optional[str] = None,
        cond_proj_dim=None,
        sample_proj_bias=True,
    ):
        super().__init__()

        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

        if cond_proj_dim is not None:
            self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
        else:
            self.cond_proj = None

        self.act = get_activation(act_fn)

        if out_dim is not None:
            time_embed_dim_out = out_dim
        else:
            time_embed_dim_out = time_embed_dim
        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

        if post_act_fn is None:
            self.post_act = None
        else:
            self.post_act = get_activation(post_act_fn)

    def forward(self, sample, condition=None):
        if condition is not None:
            sample = sample + self.cond_proj(condition)
        sample = self.linear_1(sample)

        if self.act is not None:
            sample = self.act(sample)

        sample = self.linear_2(sample)

        if self.post_act is not None:
            sample = self.post_act(sample)
        return sample


class Timesteps(nn.Module):
    """Module wrapper around `get_timestep_embedding` with fixed configuration."""

    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
        super().__init__()
        self.num_channels = num_channels
        self.flip_sin_to_cos = flip_sin_to_cos
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

    def forward(self, timesteps):
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
            flip_sin_to_cos=self.flip_sin_to_cos,
            downscale_freq_shift=self.downscale_freq_shift,
            scale=self.scale,
        )
        return t_emb


class GaussianFourierProjection(nn.Module):
    """Gaussian Fourier embeddings for noise levels."""

    def __init__(
        self, embedding_size: int = 256, scale: float = 1.0, set_W_to_weight=True, log=True, flip_sin_to_cos=False
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
        self.log = log
        self.flip_sin_to_cos = flip_sin_to_cos

        if set_W_to_weight:
            # to delete later
            # NOTE: this re-draws the random weights under the temporary name `W` and then stores them as `weight`;
            # the statement order is kept as-is so the random number stream is consumed exactly as before.
            del self.weight
            self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
            self.weight = self.W
            del self.W

    def forward(self, x):
        if self.log:
            x = torch.log(x)

        x_proj = x[:, None] * self.weight[None, :] * 2 * np.pi

        if self.flip_sin_to_cos:
            out = torch.cat([torch.cos(x_proj), torch.sin(x_proj)], dim=-1)
        else:
            out = torch.cat([torch.sin(x_proj), torch.cos(x_proj)], dim=-1)
        return out


class SinusoidalPositionalEmbedding(nn.Module):
    """Apply positional information to a sequence of embeddings.

    Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to
    them

    Args:
        embed_dim: (int): Dimension of the positional embedding.
        max_seq_length: Maximum sequence length to apply positional embeddings

    """

    def __init__(self, embed_dim: int, max_seq_length: int = 32):
        super().__init__()
        position = torch.arange(max_seq_length).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(1, max_seq_length, embed_dim)
        pe[0, :, 0::2] = torch.sin(position * div_term)
        # For odd `embed_dim` there is one fewer odd-indexed column than even-indexed ones, so trim the
        # frequencies to match; for even `embed_dim` the slice is a no-op.
        pe[0, :, 1::2] = torch.cos(position * div_term[: embed_dim // 2])
        self.register_buffer("pe", pe)

    def forward(self, x):
        _, seq_length, _ = x.shape
        x = x + self.pe[:, :seq_length]
        return x


class ImagePositionalEmbeddings(nn.Module):
    """
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

    For more details, see figure 10 of the dall-e paper: https://arxiv.org/abs/2102.12092

    For VQ-diffusion:

    Output vector embeddings are used as input for the transformer.

    Note that the vector embeddings for the transformer are different than the vector embeddings from the VQVAE.

    Args:
        num_embed (`int`):
            Number of embeddings for the latent pixels embeddings.
        height (`int`):
            Height of the latent image i.e. the number of height embeddings.
        width (`int`):
            Width of the latent image i.e. the number of width embeddings.
        embed_dim (`int`):
            Dimension of the produced vector embeddings. Used for the latent pixel, height, and width embeddings.
    """

    def __init__(
        self,
        num_embed: int,
        height: int,
        width: int,
        embed_dim: int,
    ):
        super().__init__()

        self.height = height
        self.width = width
        self.num_embed = num_embed
        self.embed_dim = embed_dim

        self.emb = nn.Embedding(self.num_embed, embed_dim)
        self.height_emb = nn.Embedding(self.height, embed_dim)
        self.width_emb = nn.Embedding(self.width, embed_dim)

    def forward(self, index):
        emb = self.emb(index)

        height_emb = self.height_emb(torch.arange(self.height, device=index.device).view(1, self.height))

        # 1 x H x D -> 1 x H x 1 x D
        height_emb = height_emb.unsqueeze(2)

        width_emb = self.width_emb(torch.arange(self.width, device=index.device).view(1, self.width))

        # 1 x W x D -> 1 x 1 x W x D
        width_emb = width_emb.unsqueeze(1)

        pos_emb = height_emb + width_emb

        # 1 x H x W x D -> 1 x L xD
        pos_emb = pos_emb.view(1, self.height * self.width, -1)

        # Only the first `emb.shape[1]` positions are used, so shorter sequences are accepted.
        emb = emb + pos_emb[:, : emb.shape[1], :]

        return emb


class LabelEmbedding(nn.Module):
    """
    Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.

    Args:
        num_classes (`int`): The number of classes.
        hidden_size (`int`): The size of the vector embeddings.
        dropout_prob (`float`): The probability of dropping a label.
    """

    def __init__(self, num_classes, hidden_size, dropout_prob):
        super().__init__()
        # One extra row (index `num_classes`) represents the "dropped" label when dropout is enabled.
        use_cfg_embedding = dropout_prob > 0
        self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size)
        self.num_classes = num_classes
        self.dropout_prob = dropout_prob

    def token_drop(self, labels, force_drop_ids=None):
        """
        Drops labels to enable classifier-free guidance.

        Args:
            labels: Tensor of class labels; the first dimension is the batch.
            force_drop_ids: Optional per-sample flags (tensor or sequence). Entries equal to 1 are dropped. When
                `None`, each label is dropped independently with probability `dropout_prob`.

        Returns:
            `labels` with dropped entries replaced by `num_classes`.
        """
        if force_drop_ids is None:
            drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
        else:
            # Convert first, compare second: comparing a Python list to 1 would yield a single `False` and
            # silently drop nothing. `as_tensor` also avoids re-wrapping an existing tensor.
            drop_ids = torch.as_tensor(force_drop_ids, device=labels.device) == 1
        labels = torch.where(drop_ids, self.num_classes, labels)
        return labels

    def forward(self, labels: torch.LongTensor, force_drop_ids=None):
        use_dropout = self.dropout_prob > 0
        if (self.training and use_dropout) or (force_drop_ids is not None):
            labels = self.token_drop(labels, force_drop_ids)
        embeddings = self.embedding_table(labels)
        return embeddings


class TextImageProjection(nn.Module):
    """Projects image embeddings to `num_image_text_embeds` tokens and prepends them to the projected text tokens."""

    def __init__(
        self,
        text_embed_dim: int = 1024,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 10,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        batch_size = text_embeds.shape[0]

        # image
        image_text_embeds = self.image_embeds(image_embeds)
        image_text_embeds = image_text_embeds.reshape(batch_size, self.num_image_text_embeds, -1)

        # text
        text_embeds = self.text_proj(text_embeds)

        return torch.cat([image_text_embeds, text_embeds], dim=1)


class ImageProjection(nn.Module):
    """Projects an image embedding to `num_image_text_embeds` layer-normalized tokens."""

    def __init__(
        self,
        image_embed_dim: int = 768,
        cross_attention_dim: int = 768,
        num_image_text_embeds: int = 32,
    ):
        super().__init__()

        self.num_image_text_embeds = num_image_text_embeds
        self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        batch_size = image_embeds.shape[0]

        # image
        image_embeds = self.image_embeds(image_embeds)
        image_embeds = image_embeds.reshape(batch_size, self.num_image_text_embeds, -1)
        image_embeds = self.norm(image_embeds)
        return image_embeds


class IPAdapterFullImageProjection(nn.Module):
    """Feed-forward projection of image embeddings followed by layer normalization."""

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024):
        super().__init__()
        from .attention import FeedForward

        self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        return self.norm(self.ff(image_embeds))


class IPAdapterFaceIDImageProjection(nn.Module):
    """Feed-forward projection of image embeddings to `num_tokens` layer-normalized tokens."""

    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
        super().__init__()
        from .attention import FeedForward

        self.num_tokens = num_tokens
        self.cross_attention_dim = cross_attention_dim
        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
        self.norm = nn.LayerNorm(cross_attention_dim)

    def forward(self, image_embeds: torch.Tensor):
        x = self.ff(image_embeds)
        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
        return self.norm(x)


class CombinedTimestepLabelEmbeddings(nn.Module):
    """Sum of a timestep embedding and a class-label embedding."""

    def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=1)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.class_embedder = LabelEmbedding(num_classes, embedding_dim, class_dropout_prob)

    def forward(self, timestep, class_labels, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)

        class_labels = self.class_embedder(class_labels)  # (N, D)

        conditioning = timesteps_emb + class_labels  # (N, D)

        return conditioning


class CombinedTimestepTextProjEmbeddings(nn.Module):
    """Sum of a timestep embedding and a projected pooled text embedding."""

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

        pooled_projections = self.text_embedder(pooled_projection)

        conditioning = timesteps_emb + pooled_projections

        return conditioning


class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
    """Sum of timestep, guidance and projected pooled text embeddings.

    The guidance value goes through the same sinusoidal projection as the timestep, but has its own embedder MLP.
    """

    def __init__(self, embedding_dim, pooled_projection_dim):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")

    def forward(self, timestep, guidance, pooled_projection):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)

        guidance_proj = self.time_proj(guidance)
        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)

        time_guidance_emb = timesteps_emb + guidance_emb

        pooled_projections = self.text_embedder(pooled_projection)
        conditioning = time_guidance_emb + pooled_projections

        return conditioning


class HunyuanDiTAttentionPool(nn.Module):
    # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
    """Attention pooling: the sequence mean is prepended as a query token that attends over the whole sequence."""

    def __init__(self, spacial_dim: int, embed_dim: int, num_heads: int, output_dim: Optional[int] = None):
        super().__init__()
        self.positional_embedding = nn.Parameter(torch.randn(spacial_dim + 1, embed_dim) / embed_dim**0.5)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.c_proj = nn.Linear(embed_dim, output_dim or embed_dim)
        self.num_heads = num_heads

    def forward(self, x):
        x = x.permute(1, 0, 2)  # NLC -> LNC
        x = torch.cat([x.mean(dim=0, keepdim=True), x], dim=0)  # (L+1)NC
        x = x + self.positional_embedding[:, None, :].to(x.dtype)  # (L+1)NC
        # Only the prepended mean token is used as the query.
        x, _ = F.multi_head_attention_forward(
            query=x[:1],
            key=x,
            value=x,
            embed_dim_to_check=x.shape[-1],
            num_heads=self.num_heads,
            q_proj_weight=self.q_proj.weight,
            k_proj_weight=self.k_proj.weight,
            v_proj_weight=self.v_proj.weight,
            in_proj_weight=None,
            in_proj_bias=torch.cat([self.q_proj.bias, self.k_proj.bias, self.v_proj.bias]),
            bias_k=None,
            bias_v=None,
            add_zero_attn=False,
            dropout_p=0,
            out_proj_weight=self.c_proj.weight,
            out_proj_bias=self.c_proj.bias,
            use_separate_proj_weight=True,
            training=self.training,
            need_weights=False,
        )
        return x.squeeze(0)


class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
    """Timestep embedding plus an embedding of pooled text and, optionally, image-size and style conditions."""

    def __init__(
        self,
        embedding_dim,
        pooled_projection_dim=1024,
        seq_len=256,
        cross_attention_dim=2048,
        use_style_cond_and_image_meta_size=True,
    ):
        super().__init__()

        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

        self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)

        self.pooler = HunyuanDiTAttentionPool(
            seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
        )

        # Here we use a default learned embedder layer for future extension.
        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
        if use_style_cond_and_image_meta_size:
            self.style_embedder = nn.Embedding(1, embedding_dim)
            # 6 image-meta values x 256 channels each, plus the style and pooled-text vectors.
            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
        else:
            extra_in_dim = pooled_projection_dim

        self.extra_embedder = PixArtAlphaTextProjection(
            in_features=extra_in_dim,
            hidden_size=embedding_dim * 4,
            out_features=embedding_dim,
            act_fn="silu_fp32",
        )

    def forward(self, timestep, encoder_hidden_states, image_meta_size, style, hidden_dtype=None):
        timesteps_proj = self.time_proj(timestep)
        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, embedding_dim)

        # extra condition1: text
        pooled_projections = self.pooler(encoder_hidden_states)  # (N, pooled_projection_dim)

        if self.use_style_cond_and_image_meta_size:
            # extra condition2: image meta size embedding
            image_meta_size = self.size_proj(image_meta_size.view(-1))
            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)

            # extra condition3: style embedding
            style_embedding = self.style_embedder(style)  # (N, embedding_dim)

            # Concatenate all extra vectors
            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
        else:
            extra_cond = torch.cat([pooled_projections], dim=1)

        conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]

        return conditioning


class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
    """Sum of a timestep embedding and an embedding of the mask-weighted mean caption feature."""

    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
        super().__init__()
        self.time_proj = Timesteps(
            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
        )

        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)

        self.caption_embedder = nn.Sequential(
            nn.LayerNorm(cross_attention_dim),
            nn.Linear(
                cross_attention_dim,
                hidden_size,
                bias=True,
            ),
        )

    def forward(self, timestep, caption_feat, caption_mask):
        # timestep embedding:
        time_freq = self.time_proj(timestep)
        time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))

        # caption condition embedding:
        # Masked mean over the sequence axis; the divisor is the per-sample sum of the mask.
        caption_mask_float = caption_mask.float().unsqueeze(-1)
        caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
        caption_feats_pool = caption_feats_pool.to(caption_feat)
        caption_embed = self.caption_embedder(caption_feats_pool)

        conditioning = time_embed + caption_embed

        return conditioning


class TextTimeEmbedding(nn.Module):
    """LayerNorm -> attention pooling -> linear projection -> LayerNorm over text encoder states."""

    def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
        super().__init__()
        self.norm1 = nn.LayerNorm(encoder_dim)
        self.pool = AttentionPooling(num_heads, encoder_dim)
        self.proj = nn.Linear(encoder_dim, time_embed_dim)
        self.norm2 = nn.LayerNorm(time_embed_dim)

    def forward(self, hidden_states):
        hidden_states = self.norm1(hidden_states)
        hidden_states = self.pool(hidden_states)
        hidden_states = self.proj(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class TextImageTimeEmbedding(nn.Module):
    """Sum of a projected, layer-normalized text embedding and a projected image embedding."""

    def __init__(self, text_embed_dim: int = 768, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.text_proj = nn.Linear(text_embed_dim, time_embed_dim)
        self.text_norm = nn.LayerNorm(time_embed_dim)
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
        # text
        time_text_embeds = self.text_proj(text_embeds)
        time_text_embeds = self.text_norm(time_text_embeds)

        # image
        time_image_embeds = self.image_proj(image_embeds)

        return time_image_embeds + time_text_embeds


class ImageTimeEmbedding(nn.Module):
    def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536):
        super().__init__()
        self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
        self.image_norm = nn.LayerNorm(time_embed_dim)

    def forward(self, image_embeds: torch.Tensor):
        # image
        time_image_embeds = self.image_proj(image_embeds)
time_image_embeds = self.image_norm(time_image_embeds) return time_image_embeds class ImageHintTimeEmbedding(nn.Module): def __init__(self, image_embed_dim: int = 768, time_embed_dim: int = 1536): super().__init__() self.image_proj = nn.Linear(image_embed_dim, time_embed_dim) self.image_norm = nn.LayerNorm(time_embed_dim) self.input_hint_block = nn.Sequential( nn.Conv2d(3, 16, 3, padding=1), nn.SiLU(), nn.Conv2d(16, 16, 3, padding=1), nn.SiLU(), nn.Conv2d(16, 32, 3, padding=1, stride=2), nn.SiLU(), nn.Conv2d(32, 32, 3, padding=1), nn.SiLU(), nn.Conv2d(32, 96, 3, padding=1, stride=2), nn.SiLU(), nn.Conv2d(96, 96, 3, padding=1), nn.SiLU(), nn.Conv2d(96, 256, 3, padding=1, stride=2), nn.SiLU(), nn.Conv2d(256, 4, 3, padding=1), ) def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor): # image time_image_embeds = self.image_proj(image_embeds) time_image_embeds = self.image_norm(time_image_embeds) hint = self.input_hint_block(hint) return time_image_embeds, hint class AttentionPooling(nn.Module): # Copied from https://github.com/deep-floyd/IF/blob/2f91391f27dd3c468bf174be5805b4cc92980c0b/deepfloyd_if/model/nn.py#L54 def __init__(self, num_heads, embed_dim, dtype=None): super().__init__() self.dtype = dtype self.positional_embedding = nn.Parameter(torch.randn(1, embed_dim) / embed_dim**0.5) self.k_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.q_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.v_proj = nn.Linear(embed_dim, embed_dim, dtype=self.dtype) self.num_heads = num_heads self.dim_per_head = embed_dim // self.num_heads def forward(self, x): bs, length, width = x.size() def shape(x): # (bs, length, width) --> (bs, length, n_heads, dim_per_head) x = x.view(bs, -1, self.num_heads, self.dim_per_head) # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) x = x.transpose(1, 2) # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head) x = x.reshape(bs * self.num_heads, -1, self.dim_per_head) 
# (bs*n_heads, length, dim_per_head) --> (bs*n_heads, dim_per_head, length) x = x.transpose(1, 2) return x class_token = x.mean(dim=1, keepdim=True) + self.positional_embedding.to(x.dtype) x = torch.cat([class_token, x], dim=1) # (bs, length+1, width) # (bs*n_heads, class_token_length, dim_per_head) q = shape(self.q_proj(class_token)) # (bs*n_heads, length+class_token_length, dim_per_head) k = shape(self.k_proj(x)) v = shape(self.v_proj(x)) # (bs*n_heads, class_token_length, length+class_token_length): scale = 1 / math.sqrt(math.sqrt(self.dim_per_head)) weight = torch.einsum("bct,bcs->bts", q * scale, k * scale) # More stable with f16 than dividing afterwards weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype) # (bs*n_heads, dim_per_head, class_token_length) a = torch.einsum("bts,bcs->bct", weight, v) # (bs, length+1, width) a = a.reshape(bs, -1, 1).transpose(1, 2) return a[:, 0, :] # cls_token def get_fourier_embeds_from_boundingbox(embed_dim, box): """ Args: embed_dim: int box: a 3-D tensor [B x N x 4] representing the bounding boxes for GLIGEN pipeline Returns: [B x N x embed_dim] tensor of positional embeddings """ batch_size, num_boxes = box.shape[:2] emb = 100 ** (torch.arange(embed_dim) / embed_dim) emb = emb[None, None, None].to(device=box.device, dtype=box.dtype) emb = emb * box.unsqueeze(-1) emb = torch.stack((emb.sin(), emb.cos()), dim=-1) emb = emb.permute(0, 1, 3, 4, 2).reshape(batch_size, num_boxes, embed_dim * 2 * 4) return emb class GLIGENTextBoundingboxProjection(nn.Module): def __init__(self, positive_len, out_dim, feature_type="text-only", fourier_freqs=8): super().__init__() self.positive_len = positive_len self.out_dim = out_dim self.fourier_embedder_dim = fourier_freqs self.position_dim = fourier_freqs * 2 * 4 # 2: sin/cos, 4: xyxy if isinstance(out_dim, tuple): out_dim = out_dim[0] if feature_type == "text-only": self.linears = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), 
nn.SiLU(), nn.Linear(512, out_dim), ) self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) elif feature_type == "text-image": self.linears_text = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, out_dim), ) self.linears_image = nn.Sequential( nn.Linear(self.positive_len + self.position_dim, 512), nn.SiLU(), nn.Linear(512, 512), nn.SiLU(), nn.Linear(512, out_dim), ) self.null_text_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) self.null_image_feature = torch.nn.Parameter(torch.zeros([self.positive_len])) self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim])) def forward( self, boxes, masks, positive_embeddings=None, phrases_masks=None, image_masks=None, phrases_embeddings=None, image_embeddings=None, ): masks = masks.unsqueeze(-1) # embedding position (it may includes padding as placeholder) xyxy_embedding = get_fourier_embeds_from_boundingbox(self.fourier_embedder_dim, boxes) # B*N*4 -> B*N*C # learnable null embedding xyxy_null = self.null_position_feature.view(1, 1, -1) # replace padding with learnable null embedding xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null # positionet with text only information if positive_embeddings is not None: # learnable null embedding positive_null = self.null_positive_feature.view(1, 1, -1) # replace padding with learnable null embedding positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1)) # positionet with text and image information else: phrases_masks = phrases_masks.unsqueeze(-1) image_masks = image_masks.unsqueeze(-1) # learnable null embedding text_null = self.null_text_feature.view(1, 1, -1) image_null = self.null_image_feature.view(1, 1, -1) # replace padding with learnable null embedding phrases_embeddings = phrases_embeddings * phrases_masks + (1 - 
phrases_masks) * text_null image_embeddings = image_embeddings * image_masks + (1 - image_masks) * image_null objs_text = self.linears_text(torch.cat([phrases_embeddings, xyxy_embedding], dim=-1)) objs_image = self.linears_image(torch.cat([image_embeddings, xyxy_embedding], dim=-1)) objs = torch.cat([objs_text, objs_image], dim=1) return objs class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module): """ For PixArt-Alpha. Reference: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29 """ def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False): super().__init__() self.outdim = size_emb_dim self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim) self.use_additional_conditions = use_additional_conditions if use_additional_conditions: self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0) self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim) def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype): timesteps_proj = self.time_proj(timestep) timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype)) # (N, D) if self.use_additional_conditions: resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype) resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1) aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype) aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1) conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1) else: conditioning = timesteps_emb return 
conditioning class PixArtAlphaTextProjection(nn.Module): """ Projects caption embeddings. Also handles dropout for classifier-free guidance. Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py """ def __init__(self, in_features, hidden_size, out_features=None, act_fn="gelu_tanh"): super().__init__() if out_features is None: out_features = hidden_size self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True) if act_fn == "gelu_tanh": self.act_1 = nn.GELU(approximate="tanh") elif act_fn == "silu": self.act_1 = nn.SiLU() elif act_fn == "silu_fp32": self.act_1 = FP32SiLU() else: raise ValueError(f"Unknown activation function: {act_fn}") self.linear_2 = nn.Linear(in_features=hidden_size, out_features=out_features, bias=True) def forward(self, caption): hidden_states = self.linear_1(caption) hidden_states = self.act_1(hidden_states) hidden_states = self.linear_2(hidden_states) return hidden_states class IPAdapterPlusImageProjectionBlock(nn.Module): def __init__( self, embed_dims: int = 768, dim_head: int = 64, heads: int = 16, ffn_ratio: float = 4, ) -> None: super().__init__() from .attention import FeedForward self.ln0 = nn.LayerNorm(embed_dims) self.ln1 = nn.LayerNorm(embed_dims) self.attn = Attention( query_dim=embed_dims, dim_head=dim_head, heads=heads, out_bias=False, ) self.ff = nn.Sequential( nn.LayerNorm(embed_dims), FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False), ) def forward(self, x, latents, residual): encoder_hidden_states = self.ln0(x) latents = self.ln1(latents) encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2) latents = self.attn(latents, encoder_hidden_states) + residual latents = self.ff(latents) + latents return latents class IPAdapterPlusImageProjection(nn.Module): """Resampler of IP-Adapter Plus. Args: embed_dims (int): The feature dimension. Defaults to 768. 
output_dims (int): The number of output channels, that is the same number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads. Defaults to 16. num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4. """ def __init__( self, embed_dims: int = 768, output_dims: int = 1024, hidden_dims: int = 1280, depth: int = 4, dim_head: int = 64, heads: int = 16, num_queries: int = 8, ffn_ratio: float = 4, ) -> None: super().__init__() self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5) self.proj_in = nn.Linear(embed_dims, hidden_dims) self.proj_out = nn.Linear(hidden_dims, output_dims) self.norm_out = nn.LayerNorm(output_dims) self.layers = nn.ModuleList( [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] ) def forward(self, x: torch.Tensor) -> torch.Tensor: """Forward pass. Args: x (torch.Tensor): Input Tensor. Returns: torch.Tensor: Output Tensor. """ latents = self.latents.repeat(x.size(0), 1, 1) x = self.proj_in(x) for block in self.layers: residual = latents latents = block(x, latents, residual) latents = self.proj_out(latents) return self.norm_out(latents) class IPAdapterFaceIDPlusImageProjection(nn.Module): """FacePerceiverResampler of IP-Adapter Plus. Args: embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels, that is the same number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024. hidden_dims (int): The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults to 8. dim_head (int): The number of head channels. Defaults to 64. 
heads (int): Parallel attention heads. Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio of feedforward network hidden layer channels. Defaults to 4. ffproj_ratio (float): The expansion ratio of feedforward network hidden layer channels (for ID embeddings). Defaults to 4. """ def __init__( self, embed_dims: int = 768, output_dims: int = 768, hidden_dims: int = 1280, id_embeddings_dim: int = 512, depth: int = 4, dim_head: int = 64, heads: int = 16, num_tokens: int = 4, num_queries: int = 8, ffn_ratio: float = 4, ffproj_ratio: int = 2, ) -> None: super().__init__() from .attention import FeedForward self.num_tokens = num_tokens self.embed_dim = embed_dims self.clip_embeds = None self.shortcut = False self.shortcut_scale = 1.0 self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio) self.norm = nn.LayerNorm(embed_dims) self.proj_in = nn.Linear(hidden_dims, embed_dims) self.proj_out = nn.Linear(embed_dims, output_dims) self.norm_out = nn.LayerNorm(output_dims) self.layers = nn.ModuleList( [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)] ) def forward(self, id_embeds: torch.Tensor) -> torch.Tensor: """Forward pass. Args: id_embeds (torch.Tensor): Input Tensor (ID embeds). Returns: torch.Tensor: Output Tensor. 
""" id_embeds = id_embeds.to(self.clip_embeds.dtype) id_embeds = self.proj(id_embeds) id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim) id_embeds = self.norm(id_embeds) latents = id_embeds clip_embeds = self.proj_in(self.clip_embeds) x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3]) for block in self.layers: residual = latents latents = block(x, latents, residual) latents = self.proj_out(latents) out = self.norm_out(latents) if self.shortcut: out = id_embeds + self.shortcut_scale * out return out class MultiIPAdapterImageProjection(nn.Module): def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]): super().__init__() self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers) def forward(self, image_embeds: List[torch.Tensor]): projected_image_embeds = [] # currently, we accept `image_embeds` as # 1. a tensor (deprecated) with shape [batch_size, embed_dim] or [batch_size, sequence_length, embed_dim] # 2. list of `n` tensors where `n` is number of ip-adapters, each tensor can hae shape [batch_size, num_images, embed_dim] or [batch_size, num_images, sequence_length, embed_dim] if not isinstance(image_embeds, list): deprecation_message = ( "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release." " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning." 
) deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False) image_embeds = [image_embeds.unsqueeze(1)] if len(image_embeds) != len(self.image_projection_layers): raise ValueError( f"image_embeds must have the same length as image_projection_layers, got {len(image_embeds)} and {len(self.image_projection_layers)}" ) for image_embed, image_projection_layer in zip(image_embeds, self.image_projection_layers): batch_size, num_images = image_embed.shape[0], image_embed.shape[1] image_embed = image_embed.reshape((batch_size * num_images,) + image_embed.shape[2:]) image_embed = image_projection_layer(image_embed) image_embed = image_embed.reshape((batch_size, num_images) + image_embed.shape[1:]) projected_image_embeds.append(image_embed) return projected_image_embeds