# -*- coding: utf-8 -*-
# @Time : 2024/7/21 5:11 PM
# @Author : xiaoshun
# @Email : 3038523973@qq.com
# @File : scnn.py
# @Software: PyCharm
# Paper: https://www.sciencedirect.com/science/article/abs/pii/S0924271624000352?via%3Dihub#fn1
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
class SCNN(nn.Module):
    """Small fully convolutional network that outputs a per-pixel class map.

    Layers, in order:
      1. 1x1 convolution, ``in_channels`` -> ``hidden_channels``, then ReLU
      2. channel-wise dropout (``nn.Dropout2d``)
      3. 1x1 convolution, ``hidden_channels`` -> ``num_classes``
      4. 3x3 convolution with padding 1, ``num_classes`` -> ``num_classes``

    Every convolution uses the default stride of 1 and the 3x3 convolution
    is padded by 1, so the output keeps the spatial size of the input.

    Args:
        in_channels: Number of channels of the input tensor.
        num_classes: Number of output channels (one per class).
        dropout_p: Probability passed to ``nn.Dropout2d``.
        hidden_channels: Output width of the first 1x1 convolution.
            Defaults to 64, the value that used to be hard-coded.
    """

    def __init__(
        self,
        in_channels: int = 3,
        num_classes: int = 2,
        dropout_p: float = 0.5,
        hidden_channels: int = 64,
    ) -> None:
        super().__init__()
        # Attribute names and creation order are kept stable so existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.conv1 = nn.Conv2d(in_channels, hidden_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(hidden_channels, num_classes, kernel_size=1)
        self.conv3 = nn.Conv2d(num_classes, num_classes, kernel_size=3, padding=1)
        self.dropout = nn.Dropout2d(p=dropout_p)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the network on ``x`` and return the raw (un-normalised) class map.

        Args:
            x: Input batch with ``in_channels`` channels in dimension 1.

        Returns:
            Tensor with ``num_classes`` channels in dimension 1 and the same
            spatial size as ``x``. No softmax/sigmoid is applied.
        """
        x = F.relu(self.conv1(x))
        x = self.dropout(x)  # only active in training mode
        x = self.conv2(x)
        x = self.conv3(x)
        return x
if __name__ == '__main__':
    # Smoke test: push a random batch through the network and print the
    # shape of the result.
    model = SCNN(num_classes=7)
    model.eval()  # turn dropout off so the check is deterministic
    fake_img = torch.randn((2, 3, 224, 224))
    with torch.no_grad():  # a shape check does not need an autograd graph
        out = model(fake_img)
    print(out.shape)
    # Expected output: torch.Size([2, 7, 224, 224])