multimodalart HF staff commited on
Commit
ba02955
1 Parent(s): 35498d7

Create previewer/modules.py

Browse files
Files changed (1) hide show
  1. previewer/modules.py +36 -0
previewer/modules.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from torch import nn


# EfficientNet-latent previewer: decodes 16x16 latents to a 64x64 RGB preview.
class Previewer(nn.Module):
    """Small convolutional decoder for quick latent previews.

    Upsamples a ``c_in``-channel 16x16 latent map (e.g. EfficientNet
    features) to a ``c_out``-channel 64x64 image via two stride-2
    transposed convolutions, with GELU + BatchNorm between layers.

    Args:
        c_in: number of input latent channels (default 16).
        c_hidden: width of the widest hidden stage (default 512).
        c_out: number of output image channels (default 3, i.e. RGB).
    """

    def __init__(self, c_in=16, c_hidden=512, c_out=3):
        super().__init__()
        half, quarter = c_hidden // 2, c_hidden // 4
        layers = [
            # Project latents up to full hidden width (1x1 conv).
            nn.Conv2d(c_in, c_hidden, kernel_size=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            # Refine at full width, spatial size unchanged.
            nn.Conv2d(c_hidden, c_hidden, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(c_hidden),

            # First 2x upsample: 16 -> 32, halve the channel count.
            nn.ConvTranspose2d(c_hidden, half, kernel_size=2, stride=2),
            nn.GELU(),
            nn.BatchNorm2d(half),

            # Refine at half width.
            nn.Conv2d(half, half, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(half),

            # Second 2x upsample: 32 -> 64, halve channels again.
            nn.ConvTranspose2d(half, quarter, kernel_size=2, stride=2),
            nn.GELU(),
            nn.BatchNorm2d(quarter),

            # Refine at quarter width.
            nn.Conv2d(quarter, quarter, kernel_size=3, padding=1),
            nn.GELU(),
            nn.BatchNorm2d(quarter),

            # Final 1x1 projection down to the output image channels.
            nn.Conv2d(quarter, c_out, kernel_size=1),
        ]
        self.blocks = nn.Sequential(*layers)

    def forward(self, x):
        """Decode latent tensor ``x`` (N, c_in, H, W) to (N, c_out, 4H, 4W)."""
        return self.blocks(x)