Spaces:
Build error
Build error
Commit
·
cd97fb0
1
Parent(s):
2f0673d
Upload srm_conv.py
Browse files- srm_conv.py +121 -0
srm_conv.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn as nn
|
4 |
+
import torch.nn.functional as F
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
|
8 |
+
class SRMConv2d_simple(nn.Module):
|
9 |
+
|
10 |
+
def __init__(self, inc=3, learnable=False):
|
11 |
+
super(SRMConv2d_simple, self).__init__()
|
12 |
+
self.truc = nn.Hardtanh(-3, 3)
|
13 |
+
kernel = self._build_kernel(inc) # (3,3,5,5)
|
14 |
+
self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
|
15 |
+
# self.hor_kernel = self._build_kernel().transpose(0,1,3,2)
|
16 |
+
|
17 |
+
def forward(self, x):
|
18 |
+
'''
|
19 |
+
x: imgs (Batch, H, W, 3)
|
20 |
+
'''
|
21 |
+
out = F.conv2d(x, self.kernel, stride=1, padding=2)
|
22 |
+
out = self.truc(out)
|
23 |
+
|
24 |
+
return out
|
25 |
+
|
26 |
+
def _build_kernel(self, inc):
|
27 |
+
# filter1: KB
|
28 |
+
filter1 = [[0, 0, 0, 0, 0],
|
29 |
+
[0, -1, 2, -1, 0],
|
30 |
+
[0, 2, -4, 2, 0],
|
31 |
+
[0, -1, 2, -1, 0],
|
32 |
+
[0, 0, 0, 0, 0]]
|
33 |
+
# filter2:KV
|
34 |
+
filter2 = [[-1, 2, -2, 2, -1],
|
35 |
+
[2, -6, 8, -6, 2],
|
36 |
+
[-2, 8, -12, 8, -2],
|
37 |
+
[2, -6, 8, -6, 2],
|
38 |
+
[-1, 2, -2, 2, -1]]
|
39 |
+
# filter3:hor 2rd
|
40 |
+
filter3 = [[0, 0, 0, 0, 0],
|
41 |
+
[0, 0, 0, 0, 0],
|
42 |
+
[0, 1, -2, 1, 0],
|
43 |
+
[0, 0, 0, 0, 0],
|
44 |
+
[0, 0, 0, 0, 0]]
|
45 |
+
|
46 |
+
filter1 = np.asarray(filter1, dtype=float) / 4.
|
47 |
+
filter2 = np.asarray(filter2, dtype=float) / 12.
|
48 |
+
filter3 = np.asarray(filter3, dtype=float) / 2.
|
49 |
+
# statck the filters
|
50 |
+
filters = [[filter1],#, filter1, filter1],
|
51 |
+
[filter2],#, filter2, filter2],
|
52 |
+
[filter3]]#, filter3, filter3]] # (3,3,5,5)
|
53 |
+
filters = np.array(filters)
|
54 |
+
filters = np.repeat(filters, inc, axis=1)
|
55 |
+
filters = torch.FloatTensor(filters) # (3,3,5,5)
|
56 |
+
return filters
|
57 |
+
|
58 |
+
class SRMConv2d_Separate(nn.Module):
|
59 |
+
|
60 |
+
def __init__(self, inc, outc, learnable=False):
|
61 |
+
super(SRMConv2d_Separate, self).__init__()
|
62 |
+
self.inc = inc
|
63 |
+
self.truc = nn.Hardtanh(-3, 3)
|
64 |
+
kernel = self._build_kernel(inc) # (3,3,5,5)
|
65 |
+
self.kernel = nn.Parameter(data=kernel, requires_grad=learnable)
|
66 |
+
# self.hor_kernel = self._build_kernel().transpose(0,1,3,2)
|
67 |
+
self.out_conv = nn.Sequential(
|
68 |
+
nn.Conv2d(3*inc, outc, 1, 1, 0, 1, 1, bias=False),
|
69 |
+
nn.BatchNorm2d(outc),
|
70 |
+
nn.ReLU(inplace=True)
|
71 |
+
)
|
72 |
+
|
73 |
+
for ly in self.out_conv.children():
|
74 |
+
if isinstance(ly, nn.Conv2d):
|
75 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
76 |
+
|
77 |
+
def forward(self, x):
|
78 |
+
'''
|
79 |
+
x: imgs (Batch, H, W, 3)
|
80 |
+
'''
|
81 |
+
out = F.conv2d(x, self.kernel, stride=1, padding=2, groups=self.inc)
|
82 |
+
out = self.truc(out)
|
83 |
+
out = self.out_conv(out)
|
84 |
+
|
85 |
+
return out
|
86 |
+
|
87 |
+
def _build_kernel(self, inc):
|
88 |
+
# filter1: KB
|
89 |
+
filter1 = [[0, 0, 0, 0, 0],
|
90 |
+
[0, -1, 2, -1, 0],
|
91 |
+
[0, 2, -4, 2, 0],
|
92 |
+
[0, -1, 2, -1, 0],
|
93 |
+
[0, 0, 0, 0, 0]]
|
94 |
+
# filter2:KV
|
95 |
+
filter2 = [[-1, 2, -2, 2, -1],
|
96 |
+
[2, -6, 8, -6, 2],
|
97 |
+
[-2, 8, -12, 8, -2],
|
98 |
+
[2, -6, 8, -6, 2],
|
99 |
+
[-1, 2, -2, 2, -1]]
|
100 |
+
# # filter3:hor 2rd
|
101 |
+
filter3 = [[0, 0, 0, 0, 0],
|
102 |
+
[0, 0, 0, 0, 0],
|
103 |
+
[0, 1, -2, 1, 0],
|
104 |
+
[0, 0, 0, 0, 0],
|
105 |
+
[0, 0, 0, 0, 0]]
|
106 |
+
|
107 |
+
filter1 = np.asarray(filter1, dtype=float) / 4.
|
108 |
+
filter2 = np.asarray(filter2, dtype=float) / 12.
|
109 |
+
filter3 = np.asarray(filter3, dtype=float) / 2.
|
110 |
+
# statck the filters
|
111 |
+
filters = [[filter1],#, filter1, filter1],
|
112 |
+
[filter2],#, filter2, filter2],
|
113 |
+
[filter3]]#, filter3, filter3]] # (3,3,5,5)
|
114 |
+
filters = np.array(filters)
|
115 |
+
# filters = np.repeat(filters, inc, axis=1)
|
116 |
+
filters = np.repeat(filters, inc, axis=0)
|
117 |
+
filters = torch.FloatTensor(filters) # (3,3,5,5)
|
118 |
+
# print(filters.size())
|
119 |
+
return filters
|
120 |
+
|
121 |
+
|