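# pyramid-flow/video_vae/modeling_causal_conv.py
# NOTE: the lines below are residue from the file-hosting web page this module
# was copied from; they are kept as comments so the module stays importable.
# multimodalart's picture
# Upload 33 files
# f0533a5 verified
# raw
# history blame
# No virus
# 4.81 kB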
from typing import Tuple, Union
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
import torch.nn.functional as F
from collections import deque
from einops import rearrange
from timm.models.layers import trunc_normal_
from IPython import embed
from torch import Tensor
from utils import (
is_context_parallel_initialized,
get_context_parallel_group,
get_context_parallel_world_size,
get_context_parallel_rank,
get_context_parallel_group_rank,
)
from .context_parallel_ops import (
conv_scatter_to_context_parallel_region,
conv_gather_from_context_parallel_region,
cp_pass_from_previous_rank,
)
def divisible_by(num, den):
    """Return True when ``num`` is an exact multiple of ``den`` (``num % den == 0``)."""
    remainder = num % den
    return remainder == 0
def cast_tuple(t, length=1):
    """Return ``t`` unchanged if it is already a tuple, else ``t`` repeated ``length`` times in a tuple."""
    if isinstance(t, tuple):
        return t
    return (t,) * length
def is_odd(n):
    """Return True when ``n`` leaves a non-zero remainder modulo 2."""
    return not (n % 2 == 0)
class CausalGroupNorm(nn.GroupNorm):
    """GroupNorm applied independently to every frame of a 5D video tensor.

    The time axis is folded into the batch axis before normalisation, so
    group statistics are computed per frame and never mix different frames.
    Input and output layout: ``(batch, channels, time, height, width)``.
    """

    def forward(self, x: Tensor) -> Tensor:
        batch, channels, frames, height, width = x.shape
        # (b, c, t, h, w) -> (b, t, c, h, w) -> (b*t, c, h, w)
        per_frame = x.permute(0, 2, 1, 3, 4).reshape(batch * frames, channels, height, width)
        normed = super().forward(per_frame)
        # (b*t, c, h, w) -> (b, t, c, h, w) -> (b, c, t, h, w)
        return normed.reshape(batch, frames, channels, height, width).permute(0, 2, 1, 3, 4)
class CausalConv3d(nn.Module):
    """3D convolution that is causal along the time axis.

    Inputs are 5D ``(batch, channels, time, height, width)`` tensors, as
    consumed by ``nn.Conv3d``. Time is padded only at the front
    (``dilation * (time_kernel - 1)`` frames), so an output frame never depends
    on later input frames. Height and width are padded symmetrically by
    ``kernel // 2``.

    Besides the plain forward pass there are two extra modes:

    * context parallel (when ``is_context_parallel_initialized()``): front
      temporal context is obtained through ``cp_pass_from_previous_rank``
      (presumably the time axis is sharded across ranks — confirm against
      that helper);
    * temporal chunking (``temporal_chunk=True``, inference only): a clip is
      fed in consecutive chunks, and the trailing padded frames of each chunk
      are cached in ``cache_front_feat`` as front context for the next one.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: ``int`` (used for all three dims) or ``(t, h, w)``;
            ``h`` and ``w`` must be odd.
        stride: ``int`` temporal stride (spatial stride is then 1), or a full
            ``(t, h, w)`` tuple.
        pad_mode: ``F.pad`` mode for the causal/spatial padding in ``forward``.
            Falls back to ``'constant'`` when the input has no more frames
            than ``time_pad``.
        **kwargs: Forwarded to ``nn.Conv3d``. ``dilation`` (default 1) is
            also used to size the temporal padding.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]],
        stride: Union[int, Tuple[int, int, int]] = 1,
        pad_mode: str = 'constant',
        **kwargs
    ):
        super().__init__()
        if isinstance(kernel_size, int):
            kernel_size = cast_tuple(kernel_size, 3)

        time_kernel_size, height_kernel_size, width_kernel_size = kernel_size
        self.time_kernel_size = time_kernel_size
        # Odd spatial kernels let ``k // 2`` padding on each side preserve H and W.
        assert is_odd(height_kernel_size) and is_odd(width_kernel_size)

        dilation = kwargs.pop('dilation', 1)
        self.pad_mode = pad_mode

        if isinstance(stride, int):
            # An int stride applies to time only; space is not strided.
            stride = (stride, 1, 1)

        # Front-only temporal padding equal to the dilated kernel span minus
        # the current frame, so output frame t sees only inputs <= t.
        time_pad = dilation * (time_kernel_size - 1)
        # NOTE(review): spatial padding ignores ``dilation``; with an int
        # dilation > 1, H and W shrink by (dilation - 1) * (k - 1).
        # Confirm dilation is always 1 in practice.
        height_pad = height_kernel_size // 2
        width_pad = width_kernel_size // 2

        self.temporal_stride = stride[0]
        self.time_pad = time_pad
        # F.pad lists the last dim first: (W_left, W_right, H_top, H_bottom, T_front, T_back).
        self.time_causal_padding = (width_pad, width_pad, height_pad, height_pad, time_pad, 0)
        # Spatial padding only; used when front context comes from the chunk
        # cache or from the previous rank instead of from padding.
        self.time_uncausal_padding = (width_pad, width_pad, height_pad, height_pad, 0, 0)

        self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=0, dilation=dilation, **kwargs)

        # Holds at most one tensor: the last two padded frames of the most
        # recent chunk processed with ``temporal_chunk=True``.
        self.cache_front_feat = deque()

    def _clear_context_parallel_cache(self):
        """Drop any cached front-context frames from a previous chunk."""
        self.cache_front_feat.clear()

    def _init_weights(self, m):
        """Initialise ``m``'s parameters.

        Linear/Conv2d/Conv3d layers get truncated-normal weights (std 0.02)
        and zero bias. LayerNorm/GroupNorm get unit weight and zero bias.
        Not called from within this class.
        """
        if isinstance(m, (nn.Linear, nn.Conv2d, nn.Conv3d)):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, (nn.LayerNorm, nn.GroupNorm)):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def context_parallel_forward(self, x):
        """Forward pass used when context parallelism is initialized.

        Front temporal context comes from ``cp_pass_from_previous_rank``, so
        only spatial padding (always ``'constant'``) is applied here.
        """
        x = cp_pass_from_previous_rank(x, dim=2, kernel_size=self.time_kernel_size)
        x = F.pad(x, self.time_uncausal_padding, mode='constant')

        cp_rank = get_context_parallel_rank()
        if cp_rank != 0:
            # Stride-2, kernel-3 layers drop the first frame on non-first ranks.
            # NOTE(review): presumably this keeps the stride-2 sampling grid
            # aligned across shard boundaries — confirm against
            # cp_pass_from_previous_rank.
            if self.temporal_stride == 2 and self.time_kernel_size == 3:
                x = x[:, :, 1:]

        x = self.conv(x)
        return x

    def _chunk_cache_supported(self):
        """Return whether chunked inference can supply this layer's front context."""
        if self.time_pad == 0:
            # Nothing needs to be prepended; the chunk is convolved as-is.
            return True
        # The cache logic in ``forward`` only prepends context for a 3-frame
        # kernel with dilation 1 (time_pad == 2) and temporal stride 1 or 2.
        return self.time_kernel_size == 3 and self.time_pad == 2 and self.temporal_stride in (1, 2)

    def forward(self, x, is_init_image=True, temporal_chunk=False):
        """Apply the causal 3D convolution.

        Args:
            x: 5D input ``(batch, channels, time, height, width)``.
            is_init_image: Only used when ``temporal_chunk`` is True. Pass True
                for the first chunk of a clip (causal padding, cache reset).
                Pass False for later chunks, whose front context is taken from
                the cache filled by the previous chunk.
            temporal_chunk: Enable chunked inference with the front-context
                cache. Must not be used in training mode.

        Returns:
            The output of ``self.conv`` on the padded (and, for later chunks,
            context-extended) input.

        Raises:
            AssertionError: If ``temporal_chunk`` is True while ``self.training``.
            NotImplementedError: For a non-initial chunk when the layer has
                temporal padding but is not a 3-frame, dilation-1 kernel with
                temporal stride 1 or 2. Previously such layers silently got
                missing front context and returned the wrong number of frames.
        """
        if is_context_parallel_initialized():
            return self.context_parallel_forward(x)

        # Use the configured pad mode only when the clip has more frames than
        # ``time_pad``; otherwise pad with zeros.
        pad_mode = self.pad_mode if self.time_pad < x.shape[2] else 'constant'

        if not temporal_chunk:
            x = F.pad(x, self.time_causal_padding, mode=pad_mode)
        else:
            assert not self.training, "The feature cache should not be used in training"
            if is_init_image:
                # First chunk: pad causally like the non-chunked path, then
                # remember its last two padded frames for the next chunk.
                x = F.pad(x, self.time_causal_padding, mode=pad_mode)
                self._clear_context_parallel_cache()
                self.cache_front_feat.append(x[:, :, -2:].clone().detach())
            else:
                # Fail before touching the cache rather than silently returning
                # a wrongly sized output.
                if not self._chunk_cache_supported():
                    raise NotImplementedError(
                        "Temporal chunking is only implemented for layers without temporal "
                        "padding or with a 3-frame, dilation-1 temporal kernel and temporal "
                        f"stride 1 or 2; got time_kernel_size={self.time_kernel_size}, "
                        f"temporal_stride={self.temporal_stride}, time_pad={self.time_pad}"
                    )
                x = F.pad(x, self.time_uncausal_padding, mode=pad_mode)
                video_front_context = self.cache_front_feat.pop()
                self._clear_context_parallel_cache()
                if self.temporal_stride == 1 and self.time_kernel_size == 3:
                    # Stride 1: both cached frames are the missing front context.
                    x = torch.cat([video_front_context, x], dim=2)
                elif self.temporal_stride == 2 and self.time_kernel_size == 3:
                    # Stride 2: only the last cached frame is prepended.
                    # NOTE(review): correctness presumably depends on the
                    # caller's chunk lengths keeping the stride-2 grid aligned.
                    x = torch.cat([video_front_context[:, :, -1:], x], dim=2)
                # Refresh the cache from this chunk's context-extended input.
                self.cache_front_feat.append(x[:, :, -2:].clone().detach())

        x = self.conv(x)
        return x