# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..utils import deprecate
from .normalization import RMSNorm
from .upsampling import upfirdn2d_native


def _prepare_fir_kernel(kernel, factor: int, gain: float) -> torch.Tensor:
    """Build the normalized 2D FIR kernel shared by the FIR downsampling paths.

    Args:
        kernel: FIR filter given as a sequence or a tensor, 1D (separable) or 2D. `None` selects `[1] * factor`.
        factor (`int`): Downsampling factor; only used to build the default kernel.
        gain (`float`): Factor the normalized kernel is multiplied by.

    Returns:
        `torch.Tensor`: A new float32 2D tensor whose entries sum to `gain`. The caller's `kernel` is never modified.
    """
    if kernel is None:
        kernel = [1] * factor

    if isinstance(kernel, torch.Tensor):
        # `torch.tensor(existing_tensor)` emits a copy-construct warning; `detach().clone()` is the
        # recommended equivalent and still leaves the caller's tensor untouched.
        kernel = kernel.detach().clone().to(dtype=torch.float32)
    else:
        kernel = torch.tensor(kernel, dtype=torch.float32)

    if kernel.ndim == 1:
        # A 1D kernel is treated as separable and expanded to its 2D outer product.
        kernel = torch.outer(kernel, kernel)

    kernel = kernel / torch.sum(kernel)
    return kernel * gain


class Downsample1D(nn.Module):
    """A 1D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution. When `False`, average pooling with kernel size and stride 2 is used.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution.
        name (`str`, default `conv`):
            name of the downsampling 1D layer.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if use_conv:
            self.conv = nn.Conv1d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
        else:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            self.conv = nn.AvgPool1d(kernel_size=stride, stride=stride)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Downsample `inputs`, whose dimension 1 must have size `channels`."""
        assert inputs.shape[1] == self.channels
        return self.conv(inputs)


class Downsample2D(nn.Module):
    """A 2D downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution. When `False`, average pooling with kernel size and stride 2 is used.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        padding (`int`, default `1`):
            padding for the convolution. With `use_conv=True` and `padding=0`, the input is zero-padded by one
            element at the end of each of its last two dimensions before the convolution.
        name (`str`, default `conv`):
            name of the downsampling 2D layer.
        kernel_size (`int`, default `3`):
            kernel size of the convolution.
        norm_type (`str`, optional):
            normalization applied over the channel dimension before downsampling: `"ln_norm"` (`nn.LayerNorm`),
            `"rms_norm"` (`RMSNorm`) or `None` for no normalization. Any other value raises a `ValueError`.
        eps (`float`, optional):
            epsilon forwarded to the normalization layer.
        elementwise_affine (`bool`, optional):
            `elementwise_affine` flag forwarded to the normalization layer.
        bias (`bool`, default `True`):
            whether the convolution has a bias.
    """

    def __init__(
        self,
        channels: int,
        use_conv: bool = False,
        out_channels: Optional[int] = None,
        padding: int = 1,
        name: str = "conv",
        kernel_size=3,
        norm_type=None,
        eps=None,
        elementwise_affine=None,
        bias=True,
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.padding = padding
        stride = 2
        self.name = name

        if norm_type == "ln_norm":
            self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
        elif norm_type == "rms_norm":
            self.norm = RMSNorm(channels, eps, elementwise_affine)
        elif norm_type is None:
            self.norm = None
        else:
            raise ValueError(f"unknown norm_type: {norm_type}")

        if use_conv:
            conv = nn.Conv2d(
                self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
            )
        else:
            # Pooling cannot change the channel count.
            assert self.channels == self.out_channels
            conv = nn.AvgPool2d(kernel_size=stride, stride=stride)

        # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
        # With the default name the layer is additionally registered as `Conv2d_0` (registered first, so the
        # submodule order is unchanged); every name exposes it as `conv`.
        if name == "conv":
            self.Conv2d_0 = conv
        self.conv = conv

    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """Normalize (if configured) and downsample `hidden_states`.

        Extra positional arguments and a non-`None` `scale` keyword are deprecated and ignored. The permutes around
        the normalization move dimension 1 (channels) to the last position and back, so the input is expected to
        be 4D with channels in dimension 1.
        """
        if len(args) > 0 or kwargs.get("scale", None) is not None:
            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
            deprecate("scale", "1.0.0", deprecation_message)
        assert hidden_states.shape[1] == self.channels

        if self.norm is not None:
            # The norm layers act on the last dimension, so go channels-last and back.
            hidden_states = self.norm(hidden_states.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

        if self.use_conv and self.padding == 0:
            pad = (0, 1, 0, 1)
            hidden_states = F.pad(hidden_states, pad, mode="constant", value=0)

        assert hidden_states.shape[1] == self.channels

        hidden_states = self.conv(hidden_states)

        return hidden_states


class FirDownsample2D(nn.Module):
    """A 2D FIR downsampling layer with an optional convolution.

    Parameters:
        channels (`int`):
            number of channels in the inputs and outputs.
        use_conv (`bool`, default `False`):
            option to use a convolution.
        out_channels (`int`, optional):
            number of output channels. Defaults to `channels`.
        fir_kernel (`tuple`, default `(1, 3, 3, 1)`):
            kernel for the FIR filter.
    """

    def __init__(
        self,
        channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        use_conv: bool = False,
        fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
    ):
        super().__init__()
        out_channels = out_channels if out_channels else channels
        if use_conv:
            # Only the weight and bias of this layer are used (see `forward`); the strided convolution itself is
            # applied through `F.conv2d` in `_downsample_2d`.
            self.Conv2d_0 = nn.Conv2d(channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.fir_kernel = fir_kernel
        self.use_conv = use_conv
        self.out_channels = out_channels

    def _downsample_2d(
        self,
        hidden_states: torch.Tensor,
        weight: Optional[torch.Tensor] = None,
        kernel: Optional[torch.Tensor] = None,
        factor: int = 2,
        gain: float = 1,
    ) -> torch.Tensor:
        """Fused `Conv2d()` followed by `downsample_2d()`.
        Padding is performed only once at the beginning, not between the operations.

        When `self.use_conv` is set, the input is FIR-filtered without subsampling and then convolved with `weight`
        using stride `factor`. Otherwise the input is FIR-filtered and subsampled by `factor` directly.

        Args:
            hidden_states (`torch.Tensor`):
                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
            weight (`torch.Tensor`, *optional*):
                Convolution weight in the `F.conv2d` layout `[outChannels, inChannels, filterH, filterW]`. Required
                when `self.use_conv` is `True`, ignored otherwise.
            kernel (`torch.Tensor`, *optional*):
                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                corresponds to average pooling.
            factor (`int`, *optional*, default to `2`):
                Integer downsampling factor.
            gain (`float`, *optional*, default to `1.0`):
                Scaling factor for signal magnitude.

        Returns:
            output (`torch.Tensor`):
                Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
                datatype as `x`.
        """

        assert isinstance(factor, int) and factor >= 1

        # setup kernel
        kernel = _prepare_fir_kernel(kernel, factor, gain)
        # Move the already-built tensor instead of re-wrapping it with `torch.tensor`, which would emit a
        # copy-construct warning on every call.
        kernel = kernel.to(device=hidden_states.device)

        if self.use_conv:
            _, _, _, convW = weight.shape
            # Padding covers both the FIR support and the convolution kernel, since the conv runs unpadded.
            pad_value = (kernel.shape[0] - factor) + (convW - 1)
            stride_value = [factor, factor]
            upfirdn_input = upfirdn2d_native(
                hidden_states,
                kernel,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )
            output = F.conv2d(upfirdn_input, weight, stride=stride_value, padding=0)
        else:
            pad_value = kernel.shape[0] - factor
            output = upfirdn2d_native(
                hidden_states,
                kernel,
                down=factor,
                pad=((pad_value + 1) // 2, pad_value // 2),
            )

        return output

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Downsample `hidden_states` by a factor of 2 using `self.fir_kernel`.

        With `use_conv`, the weight of `Conv2d_0` is applied inside the fused op and its bias is added afterwards,
        broadcast over dimension 1.
        """
        if self.use_conv:
            downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
            hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
        else:
            hidden_states = self._downsample_2d(hidden_states, kernel=self.fir_kernel, factor=2)

        return hidden_states


# downsample/upsample layer used in k-upscaler, might be able to use FirDownsample2D/FirUpsample2D instead
class KDownsample2D(nn.Module):
    r"""A 2D K-downsampling layer.

    Each channel is filtered independently with the fixed 4x4 kernel obtained as the outer product of
    `[1/8, 3/8, 3/8, 1/8]` with itself, using a stride of 2.

    Parameters:
        pad_mode (`str`, *optional*, default to `"reflect"`): the padding mode to use.
    """

    def __init__(self, pad_mode: str = "reflect"):
        super().__init__()
        self.pad_mode = pad_mode
        kernel_1d = torch.tensor([[1 / 8, 3 / 8, 3 / 8, 1 / 8]])
        self.pad = kernel_1d.shape[1] // 2 - 1
        # Non-persistent: the kernel is a constant and is not stored in the state dict.
        self.register_buffer("kernel", kernel_1d.T @ kernel_1d, persistent=False)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        """Pad `inputs` on all four sides by `self.pad` and apply the per-channel strided filter."""
        inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
        # Dense [C, C, kH, kW] weight that is zero except on the channel diagonal, so channels are not mixed.
        weight = inputs.new_zeros(
            [
                inputs.shape[1],
                inputs.shape[1],
                self.kernel.shape[0],
                self.kernel.shape[1],
            ]
        )
        indices = torch.arange(inputs.shape[1], device=inputs.device)
        kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
        weight[indices, indices] = kernel
        return F.conv2d(inputs, weight, stride=2)


class CogVideoXDownsample3D(nn.Module):
    # Todo: Wait for paper release.
    r"""
    A 3D Downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI

    Args:
        in_channels (`int`):
            Number of channels in the input image.
        out_channels (`int`):
            Number of channels produced by the convolution.
        kernel_size (`int`, defaults to `3`):
            Size of the convolving kernel.
        stride (`int`, defaults to `2`):
            Stride of the convolution.
        padding (`int`, defaults to `0`):
            Padding added to all four sides of the input.
        compress_time (`bool`, defaults to `False`):
            Whether or not to compress the time dimension.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int = 3,
        stride: int = 2,
        padding: int = 0,
        compress_time: bool = False,
    ):
        super().__init__()

        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
        self.compress_time = compress_time

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample a `(batch_size, channels, frames, height, width)` tensor.

        If `compress_time` is set, frames are first average-pooled in pairs; for an odd frame count the first
        frame is kept as is and only the remaining frames are pooled. The 2D convolution is then applied to every
        frame after zero-padding one element at the end of the height and width dimensions.
        """
        if self.compress_time:
            batch_size, channels, frames, height, width = x.shape

            # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
            x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)

            if x.shape[-1] % 2 == 1:
                x_first, x_rest = x[..., 0], x[..., 1:]
                if x_rest.shape[-1] > 0:
                    # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
                    x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)

                x = torch.cat([x_first[..., None], x_rest], dim=-1)
                # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
                x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
            else:
                # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
                x = F.avg_pool1d(x, kernel_size=2, stride=2)
                # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
                x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)

        # Pad the tensor
        pad = (0, 1, 0, 1)
        x = F.pad(x, pad, mode="constant", value=0)
        batch_size, channels, frames, height, width = x.shape
        # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
        x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
        x = self.conv(x)
        # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
        x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
        return x


def downsample_2d(
    hidden_states: torch.Tensor,
    kernel: Optional[torch.Tensor] = None,
    factor: int = 2,
    gain: float = 1,
) -> torch.Tensor:
    r"""Downsample2D a batch of 2D images with the given filter.
    Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
    given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
    specified `gain`. Pixels outside the image are assumed to be zero, and the filter is padded with zeros so that its
    shape is a multiple of the downsampling factor.

    Args:
        hidden_states (`torch.Tensor`)
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
        kernel (`torch.Tensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor (`int`, *optional*, default to `2`):
            Integer downsampling factor.
        gain (`float`, *optional*, default to `1.0`):
            Scaling factor for signal magnitude.

    Returns:
        output (`torch.Tensor`):
            Tensor of the shape `[N, C, H // factor, W // factor]`
    """

    assert isinstance(factor, int) and factor >= 1

    kernel = _prepare_fir_kernel(kernel, factor, gain)
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
        hidden_states,
        kernel.to(device=hidden_states.device),
        down=factor,
        pad=((pad_value + 1) // 2, pad_value // 2),
    )
    return output