import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from typing import Tuple, Literal
from functools import partial

from core.attention import MemEffAttention


class MVAttention(nn.Module):
    """Multi-view self-attention across all views of the same sample.

    The input packs ``num_frames`` views per sample along the batch axis
    (``[B*V, C, H, W]``). The tokens of all views of one sample are joined into
    a single sequence of length ``V*H*W`` before attention, so attention can
    mix information between views.

    Args:
        dim: Channel count ``C`` of the input feature map.
        num_heads: Forwarded to ``MemEffAttention``.
        qkv_bias: Forwarded to ``MemEffAttention``.
        proj_bias: Forwarded to ``MemEffAttention``.
        attn_drop: Forwarded to ``MemEffAttention``.
        proj_drop: Forwarded to ``MemEffAttention``.
        groups: Number of groups of the ``GroupNorm`` applied before attention.
        eps: Epsilon of that ``GroupNorm``.
        residual: If True, return ``(attn(norm(x)) + x) * skip_scale``;
            otherwise return the attention output as-is (no scaling).
        skip_scale: Multiplier applied after the residual sum.
        num_frames: Number of views ``V`` packed per sample. The input batch
            dimension must be a multiple of it.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        groups: int = 32,
        eps: float = 1e-5,
        residual: bool = True,
        skip_scale: float = 1,
        num_frames: int = 4,  # default kept for backward compatibility; callers may override
    ):
        super().__init__()

        self.residual = residual
        self.skip_scale = skip_scale
        self.num_frames = num_frames

        self.norm = nn.GroupNorm(num_groups=groups, num_channels=dim, eps=eps, affine=True)
        self.attn = MemEffAttention(dim, num_heads, qkv_bias, proj_bias, attn_drop, proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply multi-view attention.

        Args:
            x: Feature map of shape ``[B*V, C, H, W]``.

        Returns:
            Tensor of the same shape as ``x``.
        """
        BV, C, H, W = x.shape
        V = self.num_frames
        B = BV // V
        # NOTE: there is intentionally no explicit divisibility check here (it was
        # commented out upstream); a batch that is not a multiple of num_frames
        # fails in the reshape below because the element counts do not match.

        res = x
        x = self.norm(x)

        # [B*V, C, H, W] -> [B, V, H, W, C] -> [B, V*H*W, C]: one token sequence per sample.
        x = x.reshape(B, V, C, H, W).permute(0, 1, 3, 4, 2).reshape(B, -1, C)
        x = self.attn(x)
        # Inverse of the packing above: back to [B*V, C, H, W].
        x = x.reshape(B, V, H, W, C).permute(0, 1, 4, 2, 3).reshape(BV, C, H, W)

        if self.residual:
            x = (x + res) * self.skip_scale
        return x


class ResnetBlock(nn.Module):
    """Pre-activation residual block with optional 2x resampling.

    Output is ``(conv2(act(norm2(conv1(r(act(norm1(x))))))) + shortcut(r(x))) * skip_scale``
    where ``r`` is the resampling op (identity when ``resample == 'default'``).

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count; a 1x1 conv shortcut is used when it
            differs from ``in_channels``.
        resample: ``'up'`` for 2x nearest-neighbour upsampling, ``'down'`` for 2x
            average pooling, anything else for no resampling.
        groups: Number of groups of both ``GroupNorm`` layers.
        eps: Epsilon of both ``GroupNorm`` layers.
        skip_scale: Multiplier applied to the block output.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        resample: Literal['default', 'up', 'down'] = 'default',
        groups: int = 32,
        eps: float = 1e-5,
        skip_scale: float = 1,  # multiplied to output
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.skip_scale = skip_scale

        self.norm1 = nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.norm2 = nn.GroupNorm(num_groups=groups, num_channels=out_channels, eps=eps, affine=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        self.act = F.silu

        # Optional resampling applied to both the main path and the shortcut.
        self.resample = None
        if resample == 'up':
            self.resample = partial(F.interpolate, scale_factor=2.0, mode="nearest")
        elif resample == 'down':
            self.resample = nn.AvgPool2d(kernel_size=2, stride=2)

        self.shortcut = nn.Identity()
        if self.in_channels != self.out_channels:
            self.shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        res = x

        x = self.norm1(x)
        x = self.act(x)

        if self.resample is not None:
            res = self.resample(res)
            x = self.resample(x)

        x = self.conv1(x)
        x = self.norm2(x)
        x = self.act(x)
        x = self.conv2(x)

        x = (x + self.shortcut(res)) * self.skip_scale

        return x


class DownBlock(nn.Module):
    """Encoder stage: ``num_layers`` ResnetBlocks (each optionally followed by
    MVAttention), then an optional stride-2 conv that halves H and W.

    ``forward`` returns the final tensor plus a list of every intermediate output
    (one per layer, plus the downsampled output when downsampling is enabled).
    The UNet uses that list as skip connections.

    Args:
        num_frames: Views per sample, forwarded to every ``MVAttention``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        num_layers: int = 1,
        downsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
        num_frames: int = 4,
    ):
        super().__init__()

        nets = []
        attns = []
        for i in range(num_layers):
            cin = in_channels if i == 0 else out_channels
            nets.append(ResnetBlock(cin, out_channels, skip_scale=skip_scale))
            if attention:
                attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale, num_frames=num_frames))
            else:
                # Placeholder so nets and attns stay index-aligned.
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

        self.downsample = None
        if downsample:
            self.downsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x: torch.Tensor):
        xs = []

        for attn, net in zip(self.attns, self.nets):
            x = net(x)
            if attn is not None:
                x = attn(x)
            xs.append(x)

        if self.downsample is not None:
            x = self.downsample(x)
            xs.append(x)

        return x, xs


class MidBlock(nn.Module):
    """Bottleneck: one ResnetBlock followed by ``num_layers`` pairs of
    (optional MVAttention, ResnetBlock). Channel count is preserved.

    Args:
        num_frames: Views per sample, forwarded to every ``MVAttention``.
    """

    def __init__(
        self,
        in_channels: int,
        num_layers: int = 1,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
        num_frames: int = 4,
    ):
        super().__init__()

        nets = []
        attns = []
        # first layer
        nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
        # more layers
        for _ in range(num_layers):
            nets.append(ResnetBlock(in_channels, in_channels, skip_scale=skip_scale))
            if attention:
                attns.append(MVAttention(in_channels, attention_heads, skip_scale=skip_scale, num_frames=num_frames))
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.nets[0](x)
        # attns[i] runs before nets[i + 1].
        for attn, net in zip(self.attns, self.nets[1:]):
            if attn is not None:
                x = attn(x)
            x = net(x)
        return x


class UpBlock(nn.Module):
    """Decoder stage: ``num_layers`` ResnetBlocks, each fed the previous output
    concatenated (on channels) with one skip tensor. Each ResnetBlock is
    optionally followed by MVAttention. The stage ends with an optional 2x
    nearest upsample + 3x3 conv.

    Skip tensors are consumed from the END of ``xs``. The last layer's skip is
    expected to have ``prev_out_channels`` channels; the others
    ``out_channels``.

    Args:
        num_frames: Views per sample, forwarded to every ``MVAttention``.
    """

    def __init__(
        self,
        in_channels: int,
        prev_out_channels: int,
        out_channels: int,
        num_layers: int = 1,
        upsample: bool = True,
        attention: bool = True,
        attention_heads: int = 16,
        skip_scale: float = 1,
        num_frames: int = 4,
    ):
        super().__init__()

        nets = []
        attns = []
        for i in range(num_layers):
            cin = in_channels if i == 0 else out_channels
            cskip = prev_out_channels if (i == num_layers - 1) else out_channels

            nets.append(ResnetBlock(cin + cskip, out_channels, skip_scale=skip_scale))
            if attention:
                attns.append(MVAttention(out_channels, attention_heads, skip_scale=skip_scale, num_frames=num_frames))
            else:
                attns.append(None)
        self.nets = nn.ModuleList(nets)
        self.attns = nn.ModuleList(attns)

        self.upsample = None
        if upsample:
            self.upsample = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor, xs) -> torch.Tensor:
        for attn, net in zip(self.attns, self.nets):
            # Take skip connections last-in-first-out without mutating the caller's list.
            res_x = xs[-1]
            xs = xs[:-1]

            x = torch.cat([x, res_x], dim=1)
            x = net(x)
            if attn is not None:
                x = attn(x)

        if self.upsample is not None:
            x = F.interpolate(x, scale_factor=2.0, mode='nearest')
            x = self.upsample(x)

        return x


# it could be asymmetric!
class UNet(nn.Module):
    """Multi-view UNet; the up path may be shorter than the down path (asymmetric).

    Each non-final down block halves H and W. Each non-final up block doubles
    them. With fewer up than down blocks, the output is therefore lower
    resolution than the input.

    Args:
        in_channels: Input channel count.
        out_channels: Output channel count.
        down_channels: Channel count of each down block.
        down_attention: Per-down-block flag enabling MVAttention.
        mid_attention: Enable MVAttention in the mid block.
        up_channels: Channel count of each up block.
        up_attention: Per-up-block flag enabling MVAttention.
        layers_per_block: ResnetBlocks per down block (up blocks use one more).
        skip_scale: Residual scaling used by every sub-block.
        num_frames: Views ``V`` packed per sample along the batch axis;
            forwarded to every ``MVAttention``. Defaults to the previously
            hard-coded value of 4.
    """

    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_channels: Tuple[int, ...] = (64, 128, 256, 512, 1024),
        down_attention: Tuple[bool, ...] = (False, False, False, True, True),
        mid_attention: bool = True,
        up_channels: Tuple[int, ...] = (1024, 512, 256),
        up_attention: Tuple[bool, ...] = (True, True, False),
        layers_per_block: int = 2,
        skip_scale: float = np.sqrt(0.5),
        num_frames: int = 4,
    ):
        super().__init__()

        # first
        self.conv_in = nn.Conv2d(in_channels, down_channels[0], kernel_size=3, stride=1, padding=1)

        # down
        down_blocks = []
        cout = down_channels[0]
        for i in range(len(down_channels)):
            cin = cout
            cout = down_channels[i]

            down_blocks.append(DownBlock(
                cin, cout,
                num_layers=layers_per_block,
                downsample=(i != len(down_channels) - 1),  # not final layer
                attention=down_attention[i],
                skip_scale=skip_scale,
                num_frames=num_frames,
            ))
        self.down_blocks = nn.ModuleList(down_blocks)

        # mid
        self.mid_block = MidBlock(down_channels[-1], attention=mid_attention, skip_scale=skip_scale, num_frames=num_frames)

        # up
        up_blocks = []
        cout = up_channels[0]
        for i in range(len(up_channels)):
            cin = cout
            cout = up_channels[i]
            # Channel count of the skip consumed by this up block's last layer
            # (clamped so asymmetric configs don't index past the first down block).
            cskip = down_channels[max(-2 - i, -len(down_channels))]

            up_blocks.append(UpBlock(
                cin, cskip, cout,
                num_layers=layers_per_block + 1,  # one more layer for up
                upsample=(i != len(up_channels) - 1),  # not final layer
                attention=up_attention[i],
                skip_scale=skip_scale,
                num_frames=num_frames,
            ))
        self.up_blocks = nn.ModuleList(up_blocks)

        # last
        self.norm_out = nn.GroupNorm(num_channels=up_channels[-1], num_groups=32, eps=1e-5)
        self.conv_out = nn.Conv2d(up_channels[-1], out_channels, kernel_size=3, stride=1, padding=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network.

        Args:
            x: ``[B*V, Cin, H, W]``. When any attention is enabled, the batch
                dimension must be a multiple of ``num_frames``.

        Returns:
            ``[B*V, Cout, H', W']``. ``H'`` and ``W'`` depend on how many
            down/up blocks resample.
        """
        # first
        x = self.conv_in(x)

        # down: collect every intermediate output as a skip connection
        xss = [x]
        for block in self.down_blocks:
            x, xs = block(x)
            xss.extend(xs)

        # mid
        x = self.mid_block(x)

        # up: each block consumes as many skips (from the end) as it has layers
        for block in self.up_blocks:
            n_skips = len(block.nets)
            xs = xss[-n_skips:]
            xss = xss[:-n_skips]
            x = block(x, xs)

        # last
        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)  # [B, Cout, H', W']

        return x