from typing import Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import PyTorchModelHubMixin


class IntermediateBlock(nn.Module):
    """Parallel-convolution block whose output is gated per channel.

    Every convolution in the block receives the same input. Their activated
    outputs are summed, the sum is passed through each batch-norm layer (those
    results are summed as well), and the result is scaled channel-wise by a
    gate computed from the spatial mean of the input.

    Args:
        in_channels: Number of channels of the input feature map.
        num_conv_layers: Number of parallel ``nn.Conv2d`` layers (and of
            ``nn.BatchNorm2d`` layers) to create.
        conv_params: Positional arguments forwarded to ``nn.Conv2d`` right
            after ``in_channels``. The first element is therefore the number
            of output channels; the following ones map, in order, to
            ``kernel_size``, ``stride``, ``padding`` and ``dilation``.
    """

    def __init__(
        self,
        in_channels: int,
        num_conv_layers: int,
        conv_params: Sequence[int],
    ) -> None:
        super().__init__()
        out_channels = conv_params[0]
        # Registration order (convs, batch norms, fc) is kept stable so that
        # state-dict keys and module traversal order do not change.
        self.conv_layers = nn.ModuleList(
            [nn.Conv2d(in_channels, *conv_params) for _ in range(num_conv_layers)]
        )
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm2d(out_channels) for _ in range(num_conv_layers)]
        )
        # Maps the per-channel input means to one gate value per output channel.
        self.fc = nn.Linear(in_channels, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the gated sum of the parallel convolution branches.

        Args:
            x: 4-D feature map; dimensions 2 and 3 are averaged to build the
                gate and dimension 1 must match ``in_channels``.

        Returns:
            Feature map with ``conv_params[0]`` channels.
        """
        batch_size = x.size(0)
        channel_means = x.mean(dim=[2, 3])
        gate = self.fc(channel_means)

        # Sum of the activated outputs of all parallel convolutions.
        branches = [F.leaky_relu(conv(x)) for conv in self.conv_layers]
        combined = torch.stack(branches, dim=-1).sum(dim=-1)

        # Each batch norm normalises the same combined map; results are summed.
        normalised = [bn(combined) for bn in self.batch_norms]
        combined = torch.stack(normalised, dim=-1).sum(dim=-1)

        # Reshape the gate to (batch, channels, 1, 1) so it broadcasts over
        # the spatial dimensions.
        return combined * F.leaky_relu(gate.view(batch_size, -1, 1, 1))


class OutputBlock(nn.Module):
    """Classification head: spatial mean followed by a small MLP.

    Args:
        in_channels: Number of channels of the incoming feature map.
        num_classes: Size of the returned logit vector.
        hidden_sizes: Widths of the hidden fully connected layers. ``None``
            or an empty sequence yields a single linear classifier.
    """

    def __init__(
        self,
        in_channels: int,
        num_classes: int,
        hidden_sizes: Optional[Sequence[int]] = None,
    ) -> None:
        super().__init__()
        hidden_sizes = list(hidden_sizes) if hidden_sizes is not None else []
        # Consecutive pairs of this list are the (in, out) sizes of each
        # linear layer; the last pair is the classifier.
        layer_sizes = [in_channels, *hidden_sizes, num_classes]
        self.fc_layers = nn.ModuleList(
            [
                nn.Linear(n_in, n_out)
                for n_in, n_out in zip(layer_sizes[:-1], layer_sizes[1:])
            ]
        )
        # One batch norm per hidden layer; the classifier has none.
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(size) for size in hidden_sizes]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return class logits for a 4-D feature map.

        The hidden layers are applied with batch norm and leaky ReLU, then the
        final linear layer produces ``num_classes`` raw logits.

        NOTE: earlier revisions zipped ``fc_layers`` with the shorter
        ``batch_norms`` list, which silently skipped the final layer and
        returned the last hidden activation instead of logits. Checkpoints
        trained with that behaviour have an untrained final layer.
        """
        out = F.leaky_relu(x.mean(dim=[2, 3]))
        # fc_layers holds exactly one more layer than batch_norms, so zip
        # stops right before the classifier.
        for fc, bn in zip(self.fc_layers, self.batch_norms):
            out = F.leaky_relu(bn(fc(out)))
        # Final projection: no batch norm and no activation on the logits.
        return self.fc_layers[-1](out)


class CustomCIFAR10Net(nn.Module, PyTorchModelHubMixin):
    """Five gated convolution blocks followed by an MLP classification head.

    Args:
        num_classes: Number of logits produced by the output block.
    """

    def __init__(self, num_classes: int = 10) -> None:
        super().__init__()
        # conv_params are forwarded positionally to nn.Conv2d, so
        # [C, 3, 3, 1, 1] means out_channels=C, kernel_size=3, stride=3,
        # padding=1, dilation=1.
        # NOTE(review): stride=3 shrinks the feature map quickly (a 32x32
        # input would become 11, 4, 2, 1, 1 across the blocks) - confirm
        # that this stride is intended.
        self.intermediate_blocks = nn.ModuleList(
            [
                IntermediateBlock(3, 3, [64, 3, 3, 1, 1]),
                IntermediateBlock(64, 3, [128, 3, 3, 1, 1]),
                IntermediateBlock(128, 3, [256, 3, 3, 1, 1]),
                IntermediateBlock(256, 3, [512, 3, 3, 1, 1]),
                IntermediateBlock(512, 3, [1024, 3, 3, 1, 1]),
            ]
        )
        self.output_block = OutputBlock(1024, num_classes, [512, 256])
        self.dropout = nn.Dropout(0.5)

        # Kaiming-normal initialisation of every conv / linear weight; biases
        # keep their PyTorch defaults.
        # NOTE(review): the gain is computed for 'relu' although the network
        # uses leaky_relu activations - kept as in the original.
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(
                    module.weight, mode='fan_in', nonlinearity='relu'
                )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the blocks (with dropout after each one) and the output head.

        Args:
            x: Batch of 3-channel images as a 4-D tensor.

        Returns:
            Tensor of ``num_classes`` logits per sample.
        """
        for block in self.intermediate_blocks:
            x = self.dropout(block(x))
        return self.output_block(x)