r3gm's picture
Upload 224 files
c2dad70
import torch
import torch.nn as nn
class TFC(nn.Module):
def __init__(self, c, l, k, norm):
super(TFC, self).__init__()
self.H = nn.ModuleList()
for i in range(l):
self.H.append(
nn.Sequential(
nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
norm(c),
nn.ReLU(),
)
)
def forward(self, x):
for h in self.H:
x = h(x)
return x
class DenseTFC(nn.Module):
def __init__(self, c, l, k, norm):
super(DenseTFC, self).__init__()
self.conv = nn.ModuleList()
for i in range(l):
self.conv.append(
nn.Sequential(
nn.Conv2d(in_channels=c, out_channels=c, kernel_size=k, stride=1, padding=k // 2),
norm(c),
nn.ReLU(),
)
)
def forward(self, x):
for layer in self.conv[:-1]:
x = torch.cat([layer(x), x], 1)
return self.conv[-1](x)
class TFC_TDF(nn.Module):
def __init__(self, c, l, f, k, bn, dense=False, bias=True, norm=nn.BatchNorm2d):
super(TFC_TDF, self).__init__()
self.use_tdf = bn is not None
self.tfc = DenseTFC(c, l, k, norm) if dense else TFC(c, l, k, norm)
if self.use_tdf:
if bn == 0:
self.tdf = nn.Sequential(
nn.Linear(f, f, bias=bias),
norm(c),
nn.ReLU()
)
else:
self.tdf = nn.Sequential(
nn.Linear(f, f // bn, bias=bias),
norm(c),
nn.ReLU(),
nn.Linear(f // bn, f, bias=bias),
norm(c),
nn.ReLU()
)
def forward(self, x):
x = self.tfc(x)
return x + self.tdf(x) if self.use_tdf else x