import torch | |
from torch import nn | |
from torchvision import models | |
class DeePixBiS(nn.Module):
    """DenseNet-161 based network producing a per-pixel map and a scalar score.

    The encoder is the first eight children of torchvision's
    ``densenet161().features``. A 1x1 convolution reduces the 384 encoder
    channels to a single-channel map, and a linear layer turns the flattened
    14x14 map into one score per sample.

    NOTE(review): the linear head is fixed to 14 * 14 inputs, so the encoder
    must yield a 14x14 map. With the stock torchvision DenseNet-161 layout
    that presumably means 224x224 inputs -- confirm against the data pipeline.
    """

    def __init__(self, pretrained=True):
        """Build the encoder, the 1x1 decoder convolution and the linear head.

        Args:
            pretrained: ``True`` loads the ImageNet weights for the DenseNet
                backbone; any falsy value builds it without pretrained
                weights. Any other truthy value is forwarded unchanged as
                torchvision's ``weights`` argument.
        """
        super().__init__()
        if pretrained is True:
            # Passing a bare boolean as ``weights`` is deprecated in
            # torchvision; name the weights it would have resolved to.
            weights = models.DenseNet161_Weights.IMAGENET1K_V1
        elif pretrained:
            weights = pretrained
        else:
            weights = None
        dense = models.densenet161(weights=weights)
        features = list(dense.features.children())
        self.enc = nn.Sequential(*features[:8])
        self.dec = nn.Conv2d(384, 1, kernel_size=1, stride=1, padding=0)
        self.linear = nn.Linear(14 * 14, 1)

    def forward(self, x):
        """Run the network on a batch.

        Args:
            x: input batch fed to the encoder.

        Returns:
            A tuple ``(out_map, out)``: ``out_map`` is the sigmoid-activated
            single-channel map from the decoder convolution, and ``out`` is a
            1-D tensor holding one sigmoid-activated score per sample.

        Raises:
            RuntimeError: if the map does not have 14 * 14 values per sample
                (raised by the linear layer on the shape mismatch).
        """
        enc = self.enc(x)
        dec = self.dec(enc)
        out_map = torch.sigmoid(dec)
        # Flatten per sample rather than ``view(-1, 14 * 14)``: the latter lets
        # the batch dimension soak up a wrong spatial size and silently
        # returns the wrong number of scores.
        out = self.linear(out_map.flatten(start_dim=1))
        out = torch.sigmoid(out)
        out = torch.flatten(out)
        return out_map, out