Sapir's picture
refactor
c811a04
raw
history blame
1.58 kB
from typing import Tuple, Union
import torch
import torch.nn as nn
class CausalConv3d(nn.Module):
    """3D convolution that pads the temporal axis with replicated edge frames.

    The wrapped ``nn.Conv3d`` zero-pads height and width by ``kernel // 2`` and
    applies no padding along dim 2 (the temporal axis). ``forward`` pads that
    axis itself by repeating the first (and, in non-causal mode, the last)
    frame, so that with stride 1 and an odd temporal kernel the number of
    frames is preserved.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        kernel_size: Either one int used for all three axes, or a
            ``(time, height, width)`` triple.
        stride: Stride forwarded unchanged to ``nn.Conv3d``.
        **kwargs: Only ``dilation`` (an int, default 1) is read; it is applied
            to the temporal axis only. Any other keyword argument is accepted
            and ignored, so callers that pass extra conv options do not break.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size: Union[int, Tuple[int, int, int]] = 3,
        stride: Union[int, Tuple[int, int, int]] = 1,
        **kwargs,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size, kernel_size)
        else:
            kernel_size = tuple(kernel_size)
        self.time_kernel_size = kernel_size[0]

        # Dilation applies to the temporal axis only; spatial axes stay dense.
        dilation = (kwargs.pop("dilation", 1), 1, 1)

        # Spatial axes are padded by the conv itself; the temporal axis is
        # padded manually in forward(), hence the leading 0.
        padding = (0, kernel_size[1] // 2, kernel_size[2] // 2)

        self.conv = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding,
            padding_mode="zeros",
        )

    def forward(self, x, causal: bool = True):
        """Pad ``x`` along dim 2 with edge frames, then apply the convolution.

        Args:
            x: 5-D tensor; dim 2 is treated as the temporal axis.
            causal: If True, all padding goes before the first frame so an
                output frame never depends on later input frames. If False,
                the padding is split evenly between both ends.

        Returns:
            The output of the wrapped ``nn.Conv3d`` on the padded input.
        """
        # The temporal receptive field spans dilation * (kernel - 1) + 1
        # frames, so that many minus one frames of padding are required.
        # Reading dilation from the conv keeps this in sync with the layer.
        total_pad = self.conv.dilation[0] * (self.time_kernel_size - 1)

        if causal:
            pad_front, pad_back = total_pad, 0
        else:
            pad_front = pad_back = total_pad // 2

        pieces = []
        if pad_front > 0:
            pieces.append(x[:, :, :1, :, :].repeat((1, 1, pad_front, 1, 1)))
        pieces.append(x)
        if pad_back > 0:
            pieces.append(x[:, :, -1:, :, :].repeat((1, 1, pad_back, 1, 1)))

        if len(pieces) > 1:
            # torch.cat rather than torch.concatenate: the latter is an alias
            # that only exists in newer PyTorch releases.
            x = torch.cat(pieces, dim=2)

        return self.conv(x)

    @property
    def weight(self):
        """Weight tensor of the wrapped ``nn.Conv3d``."""
        return self.conv.weight