import math
from typing import Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


class DualConv3d(nn.Module):
    """Factorized 3D convolution made of two consecutive convolutions.

    The first convolution uses a ``(1, k_h, k_w)`` kernel (it only mixes the
    two trailing axes), the second uses a ``(k_d, 1, 1)`` kernel (it only
    mixes the depth axis). The same weights can be applied either through
    ``F.conv3d`` or through an equivalent ``F.conv2d`` + ``F.conv1d`` pair.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the second convolution.
        kernel_size: Int or sequence of 3 ints ``(k_d, k_h, k_w)``.
        stride: Int or sequence of 3 ints; index 0 is used by the second
            convolution, indices 1 and 2 by the first.
        padding: Int or sequence of 3 ints, split like ``stride``.
        dilation: Int or sequence of 3 ints, split like ``stride``.
        groups: ``groups`` argument forwarded to both convolutions.
        bias: If true, both convolutions get a learnable bias.

    Raises:
        ValueError: If ``kernel_size`` is ``1`` / ``(1, 1, 1)``, or if any of
            the size arguments is a sequence whose length is not 3.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride: Union[int, Tuple[int, int, int]] = 1,
        padding: Union[int, Tuple[int, int, int]] = 0,
        dilation: Union[int, Tuple[int, int, int]] = 1,
        groups=1,
        bias=True,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Normalize every size argument to a 3-tuple. Converting sequences to
        # tuples also makes the (1, 1, 1) check below work for list inputs.
        kernel_size = self._to_triple(kernel_size, "kernel_size")
        if kernel_size == (1, 1, 1):
            raise ValueError("kernel_size must be greater than 1. Use make_linear_nd instead.")
        stride = self._to_triple(stride, "stride")
        padding = self._to_triple(padding, "padding")
        dilation = self._to_triple(dilation, "dilation")

        # Set parameters for convolutions
        self.groups = groups
        self.bias = bias

        # Channel count between the two convolutions: the larger of in/out.
        intermediate_channels = out_channels if in_channels < out_channels else in_channels

        # First convolution: kernel of depth 1, acts on the two trailing axes.
        self.weight1 = nn.Parameter(
            torch.empty(intermediate_channels, in_channels // groups, 1, kernel_size[1], kernel_size[2])
        )
        self.stride1 = (1, stride[1], stride[2])
        self.padding1 = (0, padding[1], padding[2])
        self.dilation1 = (1, dilation[1], dilation[2])
        if bias:
            self.bias1 = nn.Parameter(torch.empty(intermediate_channels))
        else:
            self.register_parameter("bias1", None)

        # Second convolution: 1x1 in the trailing axes, acts on the depth axis.
        self.weight2 = nn.Parameter(
            torch.empty(out_channels, intermediate_channels // groups, kernel_size[0], 1, 1)
        )
        self.stride2 = (stride[0], 1, 1)
        self.padding2 = (padding[0], 0, 0)
        self.dilation2 = (dilation[0], 1, 1)
        if bias:
            self.bias2 = nn.Parameter(torch.empty(out_channels))
        else:
            self.register_parameter("bias2", None)

        # The parameters above are allocated uninitialized; fill them now.
        self.reset_parameters()

    @staticmethod
    def _to_triple(value, name):
        """Return ``value`` as a tuple of length 3.

        An int is repeated three times; any other value is converted with
        ``tuple`` and must contain exactly three items.

        Raises:
            ValueError: If the resulting tuple does not have length 3.
        """
        if isinstance(value, int):
            return (value, value, value)
        triple = tuple(value)
        if len(triple) != 3:
            raise ValueError(f"{name} must be an int or a sequence of 3 ints, got {value!r}.")
        return triple

    def reset_parameters(self):
        """Re-initialize both weights (Kaiming uniform) and, if present, both biases."""
        nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5))
        if self.bias:
            # Each bias is drawn uniformly from +/- 1/sqrt(fan_in) of its weight.
            fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1)
            bound1 = 1 / math.sqrt(fan_in1)
            nn.init.uniform_(self.bias1, -bound1, bound1)
            fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2)
            bound2 = 1 / math.sqrt(fan_in2)
            nn.init.uniform_(self.bias2, -bound2, bound2)

    def forward(self, x, use_conv3d=False, skip_time_conv=False):
        """Apply the factorized convolution to ``x``.

        Args:
            x: 5-D tensor laid out as ``(b, c, d, h, w)``.
            use_conv3d: If true, run both stages with ``F.conv3d``; otherwise
                use the ``F.conv2d`` + ``F.conv1d`` formulation.
            skip_time_conv: If true, return the output of the first
                convolution only (the second, depth-axis convolution is
                skipped).

        Returns:
            The convolved tensor, laid out as ``(b, c, d, h, w)``.
        """
        if use_conv3d:
            return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv)
        return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv)

    def forward_with_3d(self, x, skip_time_conv):
        """Run both stages as ``F.conv3d`` calls; see :meth:`forward` for arguments."""
        # First convolution
        x = F.conv3d(
            x,
            self.weight1,
            self.bias1,
            stride=self.stride1,
            padding=self.padding1,
            dilation=self.dilation1,
            groups=self.groups,
        )

        if skip_time_conv:
            return x

        # Second convolution
        return F.conv3d(
            x,
            self.weight2,
            self.bias2,
            stride=self.stride2,
            padding=self.padding2,
            dilation=self.dilation2,
            groups=self.groups,
        )

    def forward_with_2d(self, x, skip_time_conv):
        """Run the first stage as ``F.conv2d`` and the second as ``F.conv1d``.

        See :meth:`forward` for the meaning of the arguments.
        """
        # Unpacking five values also rejects inputs that are not 5-D.
        b, _, _, _, _ = x.shape

        # First stage: fold depth into the batch axis and convolve in 2D.
        x = rearrange(x, "b c d h w -> (b d) c h w")
        # weight1 has a depth extent of 1, so that axis can be squeezed away.
        weight1 = self.weight1.squeeze(2)
        x = F.conv2d(
            x,
            weight1,
            self.bias1,
            stride=self.stride1[1:],
            padding=self.padding1[1:],
            dilation=self.dilation1[1:],
            groups=self.groups,
        )

        # h and w may have changed because of stride/padding in the 2D stage.
        _, _, h, w = x.shape

        if skip_time_conv:
            return rearrange(x, "(b d) c h w -> b c d h w", b=b)

        # Second stage: fold h and w into the batch axis and convolve along d.
        x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b)
        # Drop the two trailing singleton axes so weight2 fits conv1d.
        weight2 = self.weight2.squeeze(-1).squeeze(-1)
        x = F.conv1d(
            x,
            weight2,
            self.bias2,
            stride=self.stride2[0],
            padding=self.padding2[0],
            dilation=self.dilation2[0],
            groups=self.groups,
        )

        return rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w)

    @property
    def weight(self):
        """Weight of the second (depth-axis) convolution."""
        return self.weight2


def test_dual_conv3d_consistency():
    """Check that the conv3d path and the conv2d + conv1d path give the same output."""
    # Initialize parameters
    in_channels = 3
    out_channels = 5
    kernel_size = (3, 3, 3)
    stride = (2, 2, 2)
    padding = (1, 1, 1)

    # Create an instance of the DualConv3d class
    dual_conv3d = DualConv3d(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=True,
    )

    # Example input tensor
    test_input = torch.randn(1, 3, 10, 10, 10)

    # Perform forward passes with both 3D and 2D settings
    output_conv3d = dual_conv3d(test_input, use_conv3d=True)
    output_2d = dual_conv3d(test_input, use_conv3d=False)

    # Assert that the outputs from both methods are sufficiently close
    assert torch.allclose(
        output_conv3d, output_2d, atol=1e-6
    ), "Outputs are not consistent between 3D and 2D convolutions."