asdasdasdasd committed on
Commit cd97fb0 · 1 Parent(s): 2f0673d

Upload srm_conv.py

Files changed (1)
  1. srm_conv.py +121 -0
srm_conv.py ADDED
@@ -0,0 +1,121 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ import numpy as np
+
+
+ class SRMConv2d_simple(nn.Module):
+     """Fixed SRM noise-residual filters applied jointly across all input channels,
+     producing 3 residual maps."""
+
+     def __init__(self, inc=3, learnable=False):
+         super(SRMConv2d_simple, self).__init__()
+         self.truc = nn.Hardtanh(-3, 3)          # truncate residuals to [-3, 3]
+         kernel = self._build_kernel(inc)        # (3, inc, 5, 5)
+         self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
+
+     def forward(self, x):
+         '''
+         x: imgs (Batch, inc, H, W)
+         '''
+         out = F.conv2d(x, self.kernel, stride=1, padding=2)
+         out = self.truc(out)
+
+         return out
+
+     def _build_kernel(self, inc):
+         # filter1: KB (second-order "square" residual)
+         filter1 = [[0, 0, 0, 0, 0],
+                    [0, -1, 2, -1, 0],
+                    [0, 2, -4, 2, 0],
+                    [0, -1, 2, -1, 0],
+                    [0, 0, 0, 0, 0]]
+         # filter2: KV (5x5 residual)
+         filter2 = [[-1, 2, -2, 2, -1],
+                    [2, -6, 8, -6, 2],
+                    [-2, 8, -12, 8, -2],
+                    [2, -6, 8, -6, 2],
+                    [-1, 2, -2, 2, -1]]
+         # filter3: horizontal second-order residual
+         filter3 = [[0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0],
+                    [0, 1, -2, 1, 0],
+                    [0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0]]
+
+         filter1 = np.asarray(filter1, dtype=float) / 4.
+         filter2 = np.asarray(filter2, dtype=float) / 12.
+         filter3 = np.asarray(filter3, dtype=float) / 2.
+         # stack the filters
+         filters = [[filter1],
+                    [filter2],
+                    [filter3]]                       # (3, 1, 5, 5)
+         filters = np.array(filters)
+         filters = np.repeat(filters, inc, axis=1)   # (3, inc, 5, 5)
+         filters = torch.FloatTensor(filters)
+         return filters
+
+
+ class SRMConv2d_Separate(nn.Module):
+     """Applies the SRM filters to each input channel separately (grouped conv),
+     then mixes the 3*inc residual maps with a learnable 1x1 conv + BN + ReLU."""
+
+     def __init__(self, inc, outc, learnable=False):
+         super(SRMConv2d_Separate, self).__init__()
+         self.inc = inc
+         self.truc = nn.Hardtanh(-3, 3)          # truncate residuals to [-3, 3]
+         kernel = self._build_kernel(inc)        # (3*inc, 1, 5, 5)
+         self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
+         self.out_conv = nn.Sequential(
+             nn.Conv2d(3 * inc, outc, 1, 1, 0, 1, 1, bias=False),
+             nn.BatchNorm2d(outc),
+             nn.ReLU(inplace=True)
+         )
+
+         for ly in self.out_conv.children():
+             if isinstance(ly, nn.Conv2d):
+                 nn.init.kaiming_normal_(ly.weight, a=1)
+
+     def forward(self, x):
+         '''
+         x: imgs (Batch, inc, H, W)
+         '''
+         out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
+         out = self.truc(out)
+         out = self.out_conv(out)
+
+         return out
+
+     def _build_kernel(self, inc):
+         # filter1: KB (second-order "square" residual)
+         filter1 = [[0, 0, 0, 0, 0],
+                    [0, -1, 2, -1, 0],
+                    [0, 2, -4, 2, 0],
+                    [0, -1, 2, -1, 0],
+                    [0, 0, 0, 0, 0]]
+         # filter2: KV (5x5 residual)
+         filter2 = [[-1, 2, -2, 2, -1],
+                    [2, -6, 8, -6, 2],
+                    [-2, 8, -12, 8, -2],
+                    [2, -6, 8, -6, 2],
+                    [-1, 2, -2, 2, -1]]
+         # filter3: horizontal second-order residual
+         filter3 = [[0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0],
+                    [0, 1, -2, 1, 0],
+                    [0, 0, 0, 0, 0],
+                    [0, 0, 0, 0, 0]]
+
+         filter1 = np.asarray(filter1, dtype=float) / 4.
+         filter2 = np.asarray(filter2, dtype=float) / 12.
+         filter3 = np.asarray(filter3, dtype=float) / 2.
+         # stack the filters
+         filters = [[filter1],
+                    [filter2],
+                    [filter3]]                       # (3, 1, 5, 5)
+         filters = np.array(filters)
+         filters = np.repeat(filters, inc, axis=0)   # (3*inc, 1, 5, 5)
+         filters = torch.FloatTensor(filters)
+         return filters
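Below is a minimal usage sketch, not part of the commit, showing how the two modules in this file might be applied to a dummy batch of RGB images; it assumes the uploaded file is importable as srm_conv from the working directory, and the shapes follow from the kernel construction above (3 residual maps for SRMConv2d_simple, outc maps after the 1x1 mixing conv in SRMConv2d_Separate).

import torch
from srm_conv import SRMConv2d_simple, SRMConv2d_Separate

x = torch.randn(2, 3, 64, 64)                 # dummy batch of RGB images (N, C, H, W)

srm = SRMConv2d_simple(inc=3)                 # fixed SRM filters shared across channels
print(srm(x).shape)                           # torch.Size([2, 3, 64, 64])

srm_sep = SRMConv2d_Separate(inc=3, outc=16)  # per-channel SRM filters + 1x1 mixing conv
print(srm_sep(x).shape)                       # torch.Size([2, 16, 64, 64])

With padding=2 and stride=1 the 5x5 filters preserve spatial size, so both modules return feature maps with the same H and W as the input.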