Spaces:
Running
Running
File size: 1,234 Bytes
201936b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import torch
import torch.nn as nn
import torchvision.models as models

from Resnet18 import *
class AlzheimerTriMatterNet(nn.Module):
    """Three-branch classifier over white-matter, gray-matter and original images.

    Each of the three inputs is passed through its own ``ResNet18`` backbone.
    The three backbone outputs are concatenated along dim 1 and fed to a
    ``Linear`` + ``Softmax(dim=1)`` head.

    NOTE(review): the head is sized for ``512 * 3`` input features, so each
    backbone presumably returns a 512-wide tensor rather than ``num_classes``
    logits -- confirm against the ``ResNet18`` definition in ``Resnet18``.

    NOTE(review): ``forward`` returns softmax outputs, not logits. Confirm the
    training loss expects that (``nn.CrossEntropyLoss`` expects unnormalized
    logits).
    """

    # Width each backbone output is assumed to have when sizing the head.
    _BRANCH_FEATURES = 512
    # Number of parallel backbones whose outputs are concatenated.
    _NUM_BRANCHES = 3

    def __init__(self, num_classes: int = 4) -> None:
        """Build the three backbones and the shared classification head.

        Args:
            num_classes: Number of output classes. Defaults to 4, the value
                that was previously hard-coded.
        """
        super().__init__()
        self.numclass = num_classes

        # Attribute names and creation order are kept as-is so that the
        # state_dict keys of existing checkpoints still match.
        self.whitematter_resnet18_model = self._make_backbone()
        self.graymatter_resnet18_model = self._make_backbone()
        self.resnet18_model = self._make_backbone()

        self.global_classification_head = nn.Sequential(
            nn.Linear(self._BRANCH_FEATURES * self._NUM_BRANCHES, self.numclass),
            nn.Softmax(dim=1),
        )

    def _make_backbone(self) -> nn.Module:
        """Create one ResNet-18 branch configured for ``self.numclass`` classes."""
        return ResNet18(
            img_channels=3,
            num_layers=18,
            block=BasicBlock,
            num_classes=self.numclass,
        )

    def forward(
        self,
        whitematter: torch.Tensor,
        graymatter: torch.Tensor,
        original: torch.Tensor,
    ) -> torch.Tensor:
        """Run the three branches and classify their concatenated outputs.

        Args:
            whitematter: Input for the white-matter backbone.
            graymatter: Input for the gray-matter backbone.
            original: Input for the backbone that sees the original image.

        Returns:
            Output of the classification head applied to the branch outputs
            concatenated in the order (white, gray, original) along dim 1.
        """
        white_output = self.whitematter_resnet18_model(whitematter)
        gray_output = self.graymatter_resnet18_model(graymatter)
        origin_output = self.resnet18_model(original)

        combined_tensor = torch.cat(
            (white_output, gray_output, origin_output), dim=1
        )
        return self.global_classification_head(combined_tensor)
|