# BlockNet10 / blocknet10.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin

class IntermediateBlock(nn.Module):
    """Parallel convolutions whose summed output is gated channel-wise,
    squeeze-and-excite style, by a small fully connected layer."""

    def __init__(self, in_channels, num_conv_layers, conv_params):
        super().__init__()
        # conv_params unpacks into Conv2d as
        # (out_channels, kernel_size, stride, padding, dilation).
        self.conv_layers = nn.ModuleList(
            [nn.Conv2d(in_channels, *conv_params) for _ in range(num_conv_layers)]
        )
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm2d(conv_params[0]) for _ in range(num_conv_layers)]
        )
        out_channels = conv_params[0]
        # One gating coefficient per output channel, computed from the
        # per-channel means of the input.
        self.fc = nn.Linear(in_channels, out_channels)

    def forward(self, x):
        batch_size = x.size(0)
        # Channel gate: global average pool -> linear -> leaky ReLU.
        channel_means = x.mean(dim=[2, 3])
        a = self.fc(channel_means)
        # Sum the parallel conv branches, then sum the per-branch batch
        # norms applied to that pooled feature map.
        x_out = torch.stack([F.leaky_relu(conv(x)) for conv in self.conv_layers], dim=-1).sum(dim=-1)
        x_out = torch.stack([bn(x_out) for bn in self.batch_norms], dim=-1).sum(dim=-1)
        # Broadcast the gate over the spatial dimensions.
        return x_out * F.leaky_relu(a.view(batch_size, -1, 1, 1))
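
# Illustrative shape walkthrough for one block (a sketch, assuming
# CIFAR-10-sized inputs; these numbers are not asserted anywhere in this
# file). With conv_params = [64, 3, 3, 1, 1] each conv uses stride 3,
# so spatial dimensions shrink roughly 3x per block:
#   x (B, 3, 32, 32) -> summed conv/bn branches (B, 64, 11, 11)
#   gate a (B, 64)   -> broadcast as (B, 64, 1, 1) over x_out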

class OutputBlock(nn.Module):
    """Classifier head: global average pooling followed by an MLP."""

    def __init__(self, in_channels, num_classes, hidden_sizes=()):
        super().__init__()
        dims = [in_channels] + list(hidden_sizes)
        # Hidden layers plus a final projection to the class logits.
        self.fc_layers = nn.ModuleList(
            [nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)]
            + [nn.Linear(dims[-1], num_classes)]
        )
        self.batch_norms = nn.ModuleList([nn.BatchNorm1d(size) for size in hidden_sizes])

    def forward(self, x):
        # Global average pool over the spatial dimensions.
        out = F.leaky_relu(x.mean(dim=[2, 3]))
        # zip() stops at the shorter batch_norms list, so only the hidden
        # layers get batch norm and an activation here.
        for fc, bn in zip(self.fc_layers, self.batch_norms):
            out = F.leaky_relu(bn(fc(out)))
        # Final classifier layer: raw logits, no norm or activation.
        return self.fc_layers[-1](out)
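
# Example (matching the instantiation below): OutputBlock(1024, 10, [512, 256])
# maps 1024 -> 512 -> 256 with BatchNorm1d + leaky ReLU at each step, then
# a plain 256 -> 10 linear layer produces the logits.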

class CustomCIFAR10Net(nn.Module, PyTorchModelHubMixin):
    """Five gated conv blocks of increasing width, then the MLP head."""

    def __init__(self, num_classes=10):
        super().__init__()
        # Each block: (in_channels, num_parallel_convs, conv_params), where
        # conv_params = [out_channels, kernel_size, stride, padding, dilation].
        self.intermediate_blocks = nn.ModuleList([
            IntermediateBlock(3, 3, [64, 3, 3, 1, 1]),
            IntermediateBlock(64, 3, [128, 3, 3, 1, 1]),
            IntermediateBlock(128, 3, [256, 3, 3, 1, 1]),
            IntermediateBlock(256, 3, [512, 3, 3, 1, 1]),
            IntermediateBlock(512, 3, [1024, 3, 3, 1, 1]),
        ])
        self.output_block = OutputBlock(1024, num_classes, [512, 256])
        self.dropout = nn.Dropout(0.5)
        # Kaiming (He) initialization for all conv and linear weights.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, x):
        for block in self.intermediate_blocks:
            x = block(x)
        x = self.dropout(x)
        return self.output_block(x)
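
if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): push
    # a random CIFAR-10-sized batch through the network and check that the
    # head returns one logit per class.
    model = CustomCIFAR10Net(num_classes=10)
    model.eval()  # use running BN statistics and disable dropout
    x = torch.randn(4, 3, 32, 32)
    with torch.no_grad():
        logits = model(x)
    print(logits.shape)  # expected: torch.Size([4, 10])

    # Pretrained weights can be pulled through PyTorchModelHubMixin; the
    # repo id below is an assumption based on the file header, verify it
    # before use:
    # model = CustomCIFAR10Net.from_pretrained("siddheshtv/BlockNet10")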