File size: 780 Bytes
9067733
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
import torch
from torch import nn
from torchvision import models


class DeePixBiS(nn.Module):
    """DenseNet-161 based model producing a pixel-wise map and a per-sample score.

    The encoder is the first eight child modules of torchvision's
    ``densenet161().features``. A 1x1 convolution reduces its output to a
    single-channel map, which is passed through a sigmoid. A linear layer
    over the flattened map then yields one sigmoid score per sample.

    Args:
        pretrained: ``True`` loads torchvision's default pretrained weights
            for DenseNet-161; any falsy value builds a randomly initialised
            backbone. For backward compatibility, a torchvision weights enum
            or weights-name string is also accepted and forwarded unchanged.
    """

    # Channel count of the truncated encoder output (input of self.dec).
    _ENC_CHANNELS = 384
    # Side length of the pixel-wise map that self.linear is sized for.
    _MAP_SIZE = 14

    def __init__(self, pretrained=True):
        super().__init__()
        dense = models.densenet161(weights=self._resolve_weights(pretrained))
        features = list(dense.features.children())
        # Keep only the first eight feature modules; presumably the stem
        # through the second transition layer of DenseNet-161 (384 channels,
        # 16x downsampling) -- confirm against the torchvision version in use.
        self.enc = nn.Sequential(*features[:8])
        self.dec = nn.Conv2d(self._ENC_CHANNELS, 1, kernel_size=1, stride=1, padding=0)
        self.linear = nn.Linear(self._MAP_SIZE * self._MAP_SIZE, 1)

    @staticmethod
    def _resolve_weights(pretrained):
        """Translate the ``pretrained`` argument into a torchvision ``weights`` value.

        torchvision rejects a bare bool for ``weights=``, so ``True`` becomes
        the documented ``"DEFAULT"`` name. Falsy values become ``None``
        (no pretrained weights). Anything else is assumed to already be a
        weights enum or name and is returned as-is.
        """
        if not pretrained:
            return None
        if isinstance(pretrained, bool):
            return "DEFAULT"
        return pretrained

    def forward(self, x):
        """Compute the pixel-wise map and the per-sample score.

        Args:
            x: Input batch for the DenseNet stem. Assumes (B, 3, 224, 224),
                which yields the 14x14 map self.linear is sized for
                -- NOTE(review): confirm input size with callers.

        Returns:
            A tuple ``(out_map, out)``:

            - ``out_map`` is the sigmoid map of shape (B, 1, h, w).
            - ``out`` is a (B,) tensor of sigmoid scores.

        Raises:
            RuntimeError: if ``h * w`` differs from 14 * 14, because
                self.linear expects exactly that many features per sample.
        """
        out_map = torch.sigmoid(self.dec(self.enc(x)))
        # flatten(1) keeps the batch dimension fixed, so a map of the wrong
        # size fails in self.linear instead of being silently regrouped
        # into extra rows (and an empty batch is handled).
        score = torch.sigmoid(self.linear(out_map.flatten(1)))
        return out_map, torch.flatten(score)