# AlzheimerDetection / AlzheimerTriMatterNet.py
# Uploaded by Jiranuwat in commit 201936b ("Upload 10 files").
import torch
import torch.nn as nn
import torchvision.models as models

from Resnet18 import *
class AlzheimerTriMatterNet(nn.Module):
    """Three-branch ResNet-18 classifier over white-matter, gray-matter and original inputs.

    Each input is passed through its own ``ResNet18`` backbone (the three
    backbones do not share weights). The three backbone outputs are
    concatenated along dim 1 and fed to a single linear layer followed by a
    softmax, so ``forward`` returns class probabilities, not logits.

    Args:
        num_classes: Number of output classes of the fusion head; also
            forwarded as ``num_classes`` to each backbone. Defaults to 4.
        feature_dim: Width of the vector each backbone returns. The fusion
            head expects ``3 * feature_dim`` input features. Defaults to 512.
        img_channels: Channel count forwarded to each backbone. Defaults to 3.

    NOTE(review): the backbones are built with ``num_classes`` while the head
    assumes ``feature_dim`` (512) features per branch. These only agree if
    ``Resnet18.ResNet18`` returns its pooled 512-d features rather than class
    logits -- confirm against Resnet18.py; if it returns logits, construct
    with ``feature_dim=num_classes``.
    """

    def __init__(self, num_classes: int = 4, feature_dim: int = 512, img_channels: int = 3):
        super().__init__()
        # Kept under its original public name for any code that reads it.
        self.numclass = num_classes

        def make_backbone():
            # One independent ResNet-18 per input branch.
            return ResNet18(
                img_channels=img_channels,
                num_layers=18,
                block=BasicBlock,
                num_classes=num_classes,
            )

        # Submodule attribute names (and construction order) are unchanged so
        # existing state_dict checkpoints load and weight init is reproducible.
        self.whitematter_resnet18_model = make_backbone()
        self.graymatter_resnet18_model = make_backbone()
        self.resnet18_model = make_backbone()
        self.global_classification_head = nn.Sequential(
            nn.Linear(feature_dim * 3, self.numclass),
            # NOTE(review): the model outputs probabilities. If training uses
            # nn.CrossEntropyLoss (which applies log-softmax itself), softmax
            # is effectively applied twice -- confirm in the training script.
            nn.Softmax(dim=1),
        )

    def forward(self, whitematter: torch.Tensor, graymatter: torch.Tensor, original: torch.Tensor) -> torch.Tensor:
        """Fuse the three branch outputs and return class probabilities.

        Args:
            whitematter: Input for the white-matter backbone.
            graymatter: Input for the gray-matter backbone.
            original: Input for the backbone on the original image.
            (Presumably each is a (batch, img_channels, H, W) image batch --
            TODO confirm against the data loader.)

        Returns:
            Softmax output of the fusion head over ``numclass`` classes,
            normalised along dim 1.
        """
        white_output = self.whitematter_resnet18_model(whitematter)
        gray_output = self.graymatter_resnet18_model(graymatter)
        origin_output = self.resnet18_model(original)
        # Join branch outputs along dim 1 (presumably the feature axis, with
        # dim 0 the batch) to form the fusion head's input.
        combined_tensor = torch.cat((white_output, gray_output, origin_output), dim=1)
        return self.global_classification_head(combined_tensor)