import torch import torch.fft as fft from torch import nn from torch.nn import functional from math import sqrt from einops import rearrange import math import numbers from typing import List # adapted from https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/10 # and https://discuss.pytorch.org/t/is-there-anyway-to-do-gaussian-filtering-for-an-image-2d-3d-in-pytorch/12351/19 def gaussian_smoothing_kernel(shape, kernel_size, sigma, dim=2): """ Apply gaussian smoothing on a 1d, 2d or 3d tensor. Filtering is performed seperately for each channel in the input using a depthwise convolution. Arguments: channels (int, sequence): Number of channels of the input tensors. Output will have this number of channels as well. kernel_size (int, sequence): Size of the gaussian kernel. sigma (float, sequence): Standard deviation of the gaussian kernel. dim (int, optional): The number of dimensions of the data. Default value is 2 (spatial). """ if isinstance(kernel_size, numbers.Number): kernel_size = [kernel_size] * dim if isinstance(sigma, numbers.Number): sigma = [sigma] * dim # The gaussian kernel is the product of the # gaussian function of each dimension. kernel = 1 meshgrids = torch.meshgrid( [ torch.arange(size, dtype=torch.float32) for size in kernel_size ] ) for size, std, mgrid in zip(kernel_size, sigma, meshgrids): mean = (size - 1) / 2 kernel *= torch.exp(-((mgrid - mean) / std) ** 2 / 2) # kernel *= 1 / (std * math.sqrt(2 * math.pi)) * \ # torch.exp(-((mgrid - mean) / std) ** 2 / 2) # Make sure sum of values in gaussian kernel equals 1. kernel = kernel / torch.sum(kernel) pad_length = (math.floor( (shape[-1]-kernel_size[-1])/2), math.floor((shape[-1]-kernel_size[-1])/2), math.floor((shape[-2]-kernel_size[-2])/2), math.floor((shape[-2]-kernel_size[-2])/2), math.floor((shape[-3]-kernel_size[-3])/2), math.floor((shape[-3]-kernel_size[-3])/2)) kernel = functional.pad(kernel, pad_length) assert kernel.shape == shape[-3:] return kernel ''' # Reshape to depthwise convolutional weight kernel = kernel.view(1, 1, *kernel.size()) kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1)) self.register_buffer('weight', kernel) self.groups = channels if dim == 1: self.conv = functional.conv1d elif dim == 2: self.conv = functional.conv2d elif dim == 3: self.conv = functional.conv3d else: raise RuntimeError( 'Only 1, 2 and 3 dimensions are supported. Received {}.'.format( dim) ) ''' class NoiseGenerator(): def __init__(self, alpha: float = 0.0, shared_noise_across_chunks: bool = False, mode="vanilla", forward_steps: int = 850, radius: List[float] = None) -> None: self.mode = mode self.alpha = alpha self.shared_noise_across_chunks = shared_noise_across_chunks self.forward_steps = forward_steps self.radius = radius def set_seed(self, seed: int): self.seed = seed def reset_seed(self, seed: int): pass def reset_noise_generator_state(self): if hasattr(self, "e_shared"): del self.e_shared def sample_noise(self, z_0: torch.tensor = None, shape=None, device=None, dtype=None, generator=None, content=None): assert (z_0 is not None) != ( shape is not None), f"either z_0 must be None, or shape must be None. Both provided." kwargs = {} noise = torch.randn(shape, **kwargs) if z_0 is None: if device is not None: kwargs["device"] = device if dtype is not None: kwargs["dtype"] = dtype else: kwargs["device"] = z_0.device kwargs["dtype"] = z_0.dtype shape = z_0.shape if generator is not None: kwargs["generator"] = generator B, F, C, W, H = shape if F == 4 and C > 4: frame_idx = 2 F, C = C, F else: frame_idx = 1 if "mixed_noise" in self.mode: shape_per_frame = [dim for dim in shape] shape_per_frame[frame_idx] = 1 zero_mean = torch.zeros( shape_per_frame, device=kwargs["device"], dtype=kwargs["dtype"]) std = torch.ones( shape_per_frame, device=kwargs["device"], dtype=kwargs["dtype"]) alpha = self.alpha std_coeff_shared = (alpha**2) / (1 + alpha**2) if self.shared_noise_across_chunks and hasattr(self, "e_shared"): e_shared = self.e_shared else: e_shared = torch.normal(mean=zero_mean, std=sqrt( std_coeff_shared)*std, generator=kwargs["generator"] if "generator" in kwargs else None) if self.shared_noise_across_chunks: self.e_shared = e_shared e_inds = [] for frame in range(shape[frame_idx]): std_coeff_ind = 1 / (1 + alpha**2) e_ind = torch.normal( mean=zero_mean, std=sqrt(std_coeff_ind)*std, generator=kwargs["generator"] if "generator" in kwargs else None) e_inds.append(e_ind) noise = torch.cat( [e_shared + e_ind for e_ind in e_inds], dim=frame_idx) if "consistI2V" in self.mode and content is not None: # if self.mode == "mixed_noise_consistI2V", we will use 'noise' from 'mixed_noise'. Otherwise, it is randn noise. if frame_idx == 1: assert content.shape[0] == noise.shape[0] and content.shape[2:] == noise.shape[2:] content = torch.concat([content, content[:, -1:].repeat( 1, noise.shape[1]-content.shape[1], 1, 1, 1)], dim=1) noise = rearrange(noise, "B F C W H -> (B C) F W H") content = rearrange(content, "B F C W H -> (B C) F W H") else: assert content.shape[:2] == noise.shape[: 2] and content.shape[3:] == noise.shape[3:] content = torch.concat( [content, content[:, :, -1:].repeat(1, 1, noise.shape[2]-content.shape[2], 1, 1)], dim=2) noise = rearrange(noise, "B C F W H -> (B C) F W H") content = rearrange(content, "B C F W H -> (B C) F W H") # TODO implement DDPM_forward using diffusers framework ''' content_noisy = ddpm_forward( content, noise, self.forward_steps) ''' # A 2D low pass filter was given in the blog: # see https://pytorch.org/blog/the-torch.fft-module-accelerated-fast-fourier-transforms-with-autograd-in-pyTorch/ # alternative # do we have to specify more (s,dim,norm?) noise_fft = fft.fftn(noise) content_noisy_fft = fft.fftn(content_noisy) # shift low frequency parts to center noise_fft_shifted = fft.fftshift(noise_fft) content_noisy_fft_shifted = fft.fftshift(content_noisy_fft) # create gaussian low pass filter 'gaussian_low_pass_filter' (specify std!) # mask out high frequencies using 'cutoff_frequence', something like gaussian_low_pass_filter[freq > cut_off_frequency] = 0.0 # TODO define 'gaussian_low_pass_filter', apply frequency cutoff filter using self.cutoff_frequency. We need to apply fft.fftshift too probably. # TODO what exactly is the "normalized space-time stop frequency" used for the cutoff? gaussian_3d = gaussian_smoothing_kernel(noise_fft.shape, kernel_size=( noise_fft.shape[-3], noise_fft.shape[-2], noise_fft.shape[-1]), sigma=1, dim=3).to(noise.device) # define cutoff frequency around the kernel center # TODO define center and cut off radius, e.g. somethink like gaussian_3d[...,:c_x-r_x,:c_y-r_y:,:c_z-r_z] = 0.0 and gaussian_3d[...,c_x+r_x:,c_y+r_y:,c_z+r_z:] = 0.0 # as we have 16 x 32 x 32, center should be (7.5,15.5,15.5) radius = self.radius # TODO we need to use rounding (ceil?) gaussian_3d[:center[0]-radius[0], :center[1] - radius[1], :center[2]-radius[2]] = 0.0 gaussian_3d[center[0]+radius[0]:, center[1]+radius[1]:, center[2]+radius[2]:] = 0.0 noise_fft_shifted_hp = noise_fft_shifted * (1 - gaussian_3d) content_noisy_fft_shifted_lp = content_noisy_fft_shifted * gaussian_3d noise = fft.ifftn(fft.ifftshift( noise_fft_shifted_hp+content_noisy_fft_shifted_lp)) if frame_idx == 1: noise = rearrange( noise, "(B C) F W H -> B F C W H", B=B) else: noise = rearrange( noise, "(B C) F W H -> B C F W H", B=B) assert noise.shape == shape return noise