File size: 2,097 Bytes
482ab8a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import torch
import torch.nn as nn
from einops import rearrange


class BayarConv2d(nn.Module):
    """Constrained 2-D convolution (the class name refers to Bayar-style kernels).

    For every (output channel, input channel) pair the kernel's centre tap is
    fixed at ``-magnitude`` while the remaining ``kernel_size**2 - 1`` taps
    are rescaled so that they sum to ``+magnitude``; each full kernel thus
    sums to zero.  Only the off-centre taps are trainable.

    Args:
        in_channles: Number of input channels.  The spelling is kept as-is so
            existing keyword callers keep working.
        out_channels: Number of output channels.
        kernel_size: Side length of the square kernel; must be greater than 1.
        stride: Stride forwarded to ``torch.nn.functional.conv2d``.
        padding: Padding forwarded to ``torch.nn.functional.conv2d``.
        magnitude: Absolute value of the fixed centre tap, and the value the
            off-centre taps are normalised to sum to.
    """

    def __init__(
        self,
        in_channles: int,
        out_channels: int,
        kernel_size: int = 5,
        stride: int = 1,
        padding: int = 0,
        magnitude: float = 1.0,
    ) -> None:
        super().__init__()
        assert kernel_size > 1, "Bayar conv kernel size must be greater than 1"

        self.in_channels = in_channles
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.magnitude = magnitude

        # Parameters are laid out as (out_channels, in_channels, taps) because
        # conv2d expects the output-channel axis first.  With the axes the
        # other way round the layer only ran when both counts were equal; for
        # equal counts the shapes (and therefore the results) are unchanged.
        self.center_weight = nn.Parameter(
            torch.ones(self.out_channels, self.in_channels, 1) * -1.0 * magnitude,
            requires_grad=False,
        )
        self.kernel_weight = nn.Parameter(
            torch.rand((self.out_channels, self.in_channels, kernel_size**2 - 1)),
            requires_grad=True,
        )

    def _constraint_weight(self) -> torch.Tensor:
        """Re-normalise the trainable taps and assemble the full kernel.

        The stored ``kernel_weight`` is overwritten with its normalised
        values on every call (a projection applied outside autograd), then the
        fixed centre tap is spliced into the flattened kernel.

        Returns:
            Tensor of shape ``(out_channels, in_channels, kernel_size,
            kernel_size)`` ready to be passed to ``conv2d``.
        """
        # Rescale in place under no_grad: the projection is not part of the
        # autograd graph and the parameter keeps its storage between calls.
        with torch.no_grad():
            tap_sum = self.kernel_weight.sum(dim=2, keepdim=True)
            self.kernel_weight.div_(tap_sum).mul_(self.magnitude)

        # Position of the centre tap in the row-major flattened kernel.
        center_idx = self.kernel_size**2 // 2
        full_kernel = torch.cat(
            [
                self.kernel_weight[:, :, :center_idx],
                self.center_weight,
                self.kernel_weight[:, :, center_idx:],
            ],
            dim=2,
        )
        # Unflatten the tap axis into (kernel_size, kernel_size), row-major.
        return full_kernel.reshape(
            self.out_channels,
            self.in_channels,
            self.kernel_size,
            self.kernel_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Convolve ``x`` with the freshly constrained kernel.

        Args:
            x: Input passed to ``conv2d``; its channel dimension must match
                ``in_channels``.

        Returns:
            The result of ``conv2d`` with this layer's stride and padding.
        """
        return nn.functional.conv2d(
            x, self._constraint_weight(), stride=self.stride, padding=self.padding
        )


if __name__ == "__main__":
    # Smoke test: build a 3->3 channel layer with a 3x3 kernel and push one
    # random batch through it.  Fall back to the CPU when CUDA is unavailable
    # so the script does not crash on machines without a GPU.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    bayer_conv2d = BayarConv2d(3, 3, 3, magnitude=1).to(device)
    bayer_conv2d._constraint_weight()
    i = torch.rand(16, 3, 16, 16).to(device)
    o = bayer_conv2d(i)