import torch
import torch.nn as nn
import torchvision

EPS = 1e-7
class ConfNet(nn.Module):
    """Confidence network: encodes a 64x64 input image to a latent code and decodes
    it back to a 2-channel 16x16 confidence map (used for the perceptual loss)."""

    def __init__(self, cin=3, cout=1, zdim=128, nf=64):
        super(ConfNet, self).__init__()
        ## downsampling
        network = [
            nn.Conv2d(cin, nf, kernel_size=4, stride=2, padding=1, bias=False),  # 64x64 -> 32x32
            nn.GroupNorm(16, nf),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf, nf*2, kernel_size=4, stride=2, padding=1, bias=False),  # 32x32 -> 16x16
            nn.GroupNorm(16*2, nf*2),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf*2, nf*4, kernel_size=4, stride=2, padding=1, bias=False),  # 16x16 -> 8x8
            nn.GroupNorm(16*4, nf*4),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf*4, nf*8, kernel_size=4, stride=2, padding=1, bias=False),  # 8x8 -> 4x4
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(nf*8, zdim, kernel_size=4, stride=1, padding=0, bias=False),  # 4x4 -> 1x1
            nn.ReLU(inplace=True)]
        ## upsampling
        network += [
            nn.ConvTranspose2d(zdim, nf*8, kernel_size=4, padding=0, bias=False),  # 1x1 -> 4x4
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nf*8, nf*4, kernel_size=4, stride=2, padding=1, bias=False),  # 4x4 -> 8x8
            nn.GroupNorm(16*4, nf*4),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(nf*4, nf*2, kernel_size=4, stride=2, padding=1, bias=False),  # 8x8 -> 16x16
            nn.GroupNorm(16*2, nf*2),
            nn.ReLU(inplace=True)]
        self.network = nn.Sequential(*network)

        # ! only the symmetric confidence is required
        # out_net1 = [
        #     nn.ConvTranspose2d(nf*2, nf, kernel_size=4, stride=2, padding=1, bias=False),  # 16x16 -> 32x32
        #     nn.GroupNorm(16, nf),
        #     nn.ReLU(inplace=True),
        #     nn.ConvTranspose2d(nf, nf, kernel_size=4, stride=2, padding=1, bias=False),  # 32x32 -> 64x64
        #     nn.GroupNorm(16, nf),
        #     nn.ReLU(inplace=True),
        #     nn.Conv2d(nf, 2, kernel_size=5, stride=1, padding=2, bias=False),  # 64x64
        #     # nn.Conv2d(nf, 1, kernel_size=5, stride=1, padding=2, bias=False),  # 64x64
        #     nn.Softplus()
        # ]
        # self.out_net1 = nn.Sequential(*out_net1)

        # ! for perceptual loss
        out_net2 = [nn.Conv2d(nf*2, 2, kernel_size=3, stride=1, padding=1, bias=False),  # 16x16
                    nn.Softplus()
                    # nn.Sigmoid()
                    ]
        self.out_net2 = nn.Sequential(*out_net2)

    def forward(self, input):
        out = self.network(input)
        # return self.out_net1(out)
        return self.out_net2(out)
        # return self.out_net1(out), self.out_net2(out)
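

# A minimal usage sketch (not part of the original file, assumes 64x64 RGB inputs as
# implied by the layer comments above): the encoder-decoder returns a Bx2x16x16
# confidence map, kept positive by the Softplus in out_net2.
if __name__ == "__main__":
    net = ConfNet(cin=3, cout=1, zdim=128, nf=64)
    x = torch.randn(4, 3, 64, 64)   # batch of 4 RGB images at 64x64
    conf = net(x)
    print(conf.shape)               # expected: torch.Size([4, 2, 16, 16])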