# fewshot_random_subspace/models/ResNet12_embedding.py
import timm
import numpy as np
import torch.nn as nn
import torch
import torch.nn.functional as F
from models.dropblock import DropBlock
from models.vit import ViT
from torchvision import models
import random
# This ResNet network was designed following the practice of the following papers:
# TADAM: Task dependent adaptive metric for improved few-shot learning (Oreshkin et al., in NIPS 2018) and
# A Simple Neural Attentive Meta-Learner (Mishra et al., in ICLR 2018).
def conv3x3(in_planes, out_planes, stride=1):
    """Build a bias-free 3x3 ``nn.Conv2d`` with one pixel of zero padding.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the convolution (defaults to 1).

    Returns:
        The configured ``nn.Conv2d`` layer.
    """
    conv_options = dict(kernel_size=3, stride=stride, padding=1, bias=False)
    return nn.Conv2d(in_planes, out_planes, **conv_options)
class BasicBlock(nn.Module):
    """Residual block: three 3x3 conv + batch-norm stages, then max-pooling.

    All three convolutions run at stride 1; ``stride`` is only used as the
    kernel size of the trailing ``nn.MaxPool2d``. LeakyReLU(0.1) follows the
    first two stages and the residual sum. When ``drop_rate > 0`` the pooled
    output is regularised with DropBlock (if ``drop_block`` is set) or with
    ordinary element-wise dropout (otherwise).

    Args:
        inplanes: Number of input channels.
        planes: Number of output channels of every convolution in the block.
        stride: Kernel size handed to ``nn.MaxPool2d``.
        downsample: Optional module applied to the input to produce the
            residual branch; when ``None`` the input itself is the residual.
        drop_rate: Dropout / DropBlock rate; ``0`` disables both.
        drop_block: Use DropBlock instead of plain dropout.
        block_size: Block size passed to ``DropBlock``.
    """

    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 drop_rate=0.0, drop_block=False, block_size=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.relu = nn.LeakyReLU(0.1)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv3x3(planes, planes)
        self.bn3 = nn.BatchNorm2d(planes)
        self.maxpool = nn.MaxPool2d(stride)
        self.downsample = downsample
        self.stride = stride
        self.drop_rate = drop_rate
        # Plain Python counter (not a buffer) driving the DropBlock schedule.
        self.num_batches_tracked = 0
        self.drop_block = drop_block
        self.block_size = block_size
        self.DropBlock = DropBlock(block_size=self.block_size)

    def forward(self, x):
        """Apply the residual block to ``x`` and return the pooled feature map."""
        # Incremented on every call, regardless of train/eval mode.
        self.num_batches_tracked += 1

        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out += residual
        out = self.relu(out)
        out = self.maxpool(out)

        if self.drop_rate > 0:
            if self.drop_block:
                feat_size = out.size()[2]
                # Keep-rate decays linearly from 1.0 towards 1 - drop_rate
                # over 20 * 2000 forward calls, then stays at that floor.
                keep_rate = max(
                    1.0 - self.drop_rate / (20 * 2000) * self.num_batches_tracked,
                    1.0 - self.drop_rate)
                # Rescale the drop probability by block size and feature-map
                # size before handing it to DropBlock.
                gamma = (1 - keep_rate) / self.block_size**2 * \
                    feat_size**2 / (feat_size - self.block_size + 1)**2
                out = self.DropBlock(out, gamma=gamma)
            else:
                # NOTE(review): DropBlock and plain dropout are treated as
                # mutually exclusive, following the TADAM/MetaOptNet-style
                # ResNet-12 this block is modelled on - confirm this matches
                # the intended training recipe.
                out = F.dropout(out, p=self.drop_rate,
                                training=self.training, inplace=True)

        return out
class ResNet(nn.Module):
    """Four-stage residual embedding network (64 -> 160 -> 320 -> 640 channels).

    Each stage is a single ``block`` built by ``_make_layer``; the last two
    stages enable DropBlock. ``forward`` returns the flattened output of the
    final stage (after optional average pooling).

    Args:
        block: Block class; must expose an ``expansion`` attribute and accept
            ``(inplanes, planes, stride, downsample, drop_rate, drop_block,
            block_size)`` positionally.
        keep_prob: Keep probability used to configure ``self.dropout``
            (which is constructed but not applied in ``forward``).
        avg_pool: When true, apply ``nn.AvgPool2d(5, stride=1)`` after the
            last stage.
        drop_rate: Drop rate forwarded to every block.
        dropblock_size: DropBlock block size used by stages 3 and 4.
    """

    def __init__(self, block, keep_prob=1.0, avg_pool=False, drop_rate=0.0, dropblock_size=5):
        super(ResNet, self).__init__()
        # Channel count of the tensor entering the next stage; starts at 3
        # input channels and is updated by _make_layer.
        self.inplanes = 3

        self.layer1 = self._make_layer(
            block, 64, stride=2, drop_rate=drop_rate)
        self.layer2 = self._make_layer(
            block, 160, stride=2, drop_rate=drop_rate)
        self.layer3 = self._make_layer(
            block, 320, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
        self.layer4 = self._make_layer(
            block, 640, stride=2, drop_rate=drop_rate, drop_block=True, block_size=dropblock_size)
        if avg_pool:
            self.avgpool = nn.AvgPool2d(5, stride=1)
        self.keep_prob = keep_prob
        self.keep_avg_pool = avg_pool
        self.dropout = nn.Dropout(p=1 - self.keep_prob, inplace=False)
        self.drop_rate = drop_rate

        # Three two-layer projection heads. They are registered (and so show
        # up in state_dict) but are currently not used by forward().
        self.linear1_1 = nn.Linear(2560, 512)
        self.linear1_2 = nn.Linear(512, 64)
        self.linear2_1 = nn.Linear(2560, 512)
        self.linear2_2 = nn.Linear(512, 64)
        self.linear3_1 = nn.Linear(2560, 512)
        self.linear3_2 = nn.Linear(512, 64)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='leaky_relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def _make_layer(self, block, planes, stride=1, drop_rate=0.0, drop_block=False, block_size=1):
        """Build one stage consisting of a single ``block``.

        A 1x1 conv + batch-norm projection is created for the residual branch
        whenever ``stride != 1`` or the channel count changes. The projection
        itself uses stride 1. ``self.inplanes`` is updated to the stage's
        output channel count.
        """
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(self.inplanes, planes * block.expansion,
                          kernel_size=1, stride=1, bias=False),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride,
                            downsample, drop_rate, drop_block, block_size))
        self.inplanes = planes * block.expansion

        return nn.Sequential(*layers)

    def forward(self, x):
        """Return the flattened feature map of the last stage, one row per sample."""
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if self.keep_avg_pool:
            x = self.avgpool(x)
        x = x.view(x.size(0), -1)
        # The linear{1,2,3}_{1,2} heads are intentionally not applied here;
        # the raw flattened embedding is returned.
        return x
class custom_model(nn.Module):
    """Pretrained timm backbone followed by ``num_layer`` independent MLP heads.

    Each head ``i`` is ``linear{i}_1`` (1000 -> 512), ``batch_norm{i}_1``,
    ReLU, then ``linear{i}_2`` (512 -> 64). ``forward`` runs every head on the
    backbone output and stacks the results along a new leading dimension.

    Args:
        num_layer: Number of heads to create.
    """

    def __init__(self, num_layer=5):
        super(custom_model, self).__init__()
        # Backbone; densenet121, torchvision resnet34 and a ViT were tried
        # here previously.
        self.classifier = timm.create_model('tf_efficientnet_b7_ns', pretrained=True)
        self.num_layer = num_layer
        # bn2 and dropout are registered but not used by forward().
        self.bn2 = nn.BatchNorm1d(num_features=128)
        self.dropout = nn.Dropout(0.4)
        for i in range(num_layer):
            setattr(self, "linear%d_1" % i, nn.Linear(1000, 512))
            setattr(self, "batch_norm%d_1" % i, nn.BatchNorm1d(num_features=512))
            setattr(self, "linear%d_2" % i, nn.Linear(512, 64))

        # Re-initialise every nn.Linear weight from N(0, 1/fan), where fan is
        # drawn uniformly from [in_features // 2, in_features], so each head
        # starts at a different random scale. Biases keep their default init.
        # NOTE(review): self.modules() also walks the backbone, so any
        # nn.Linear inside self.classifier (e.g. its pretrained output layer)
        # is re-initialised as well - confirm this is intended.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                # Integer division: random.randint rejects float bounds on
                # Python >= 3.12.
                fan = random.randint(m.in_features // 2, m.in_features)
                m.weight.data.normal_(0.0, 1 / np.sqrt(fan))
        self.act = nn.ReLU()

    def forward(self, x):
        """Return the stacked head outputs.

        The result has one leading entry per head, each of width 64
        (assuming the backbone yields a ``(batch, 1000)`` tensor, the shape
        is ``(num_layer, batch, 64)``).
        """
        x = self.classifier(x)
        feat = []
        for i in range(self.num_layer):
            hidden = getattr(self, "linear%d_1" % i)(x)
            hidden = self.act(getattr(self, "batch_norm%d_1" % i)(hidden))
            # Collect each head's projection; without this the stack below
            # would be called on an empty list.
            feat.append(getattr(self, "linear%d_2" % i)(hidden))
        return torch.stack(feat, dim=0)
def conv_block(in_channels, out_channels):
    """Return a conv-bn-relu-pool block.

    The block is a 3x3 convolution (padding 1), batch normalisation, ReLU and
    a 2x2 max-pool, as in the prototypical-networks reference encoder.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.

    Returns:
        An ``nn.Sequential`` holding the four layers.
    """
    return nn.Sequential(
        nn.Conv2d(in_channels, out_channels, 3, padding=1),
        nn.BatchNorm2d(out_channels),
        nn.ReLU(),
        nn.MaxPool2d(2),
    )
class ProtoNet(nn.Module):
    """Four-block convolutional encoder.

    Model as described in the reference paper,
    source: https://github.com/jakesnell/prototypical-networks/blob/f0c48808e496989d01db59f86d4449d7aee9ab0c/protonets/models/few_shot.py#L62-L84

    Args:
        x_dim: Number of input channels.
        hid_dim: Number of channels of the three hidden blocks.
        z_dim: Number of channels of the final block.
    """

    def __init__(self, x_dim=3, hid_dim=64, z_dim=64):
        super(ProtoNet, self).__init__()
        self.encoder = nn.Sequential(
            conv_block(x_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, hid_dim),
            conv_block(hid_dim, z_dim),
        )

    def forward(self, x):
        """Encode ``x`` and return a one-element list with the flattened features."""
        x = self.encoder(x)
        # Wrapped in a list, unlike ResNet.forward which returns a bare tensor.
        return [x.view(x.size(0), -1)]
def resnet12(keep_prob=1.0, avg_pool=False, num_layer=5, **kwargs):
    """Construct the embedding model exposed under the ResNet-12 entry point.

    Despite the name, this currently builds a ``custom_model`` rather than
    ``ResNet(BasicBlock, ...)``. ``keep_prob``, ``avg_pool`` and ``**kwargs``
    are accepted for call-site compatibility but are not used.

    Args:
        keep_prob: Ignored.
        avg_pool: Ignored.
        num_layer: Number of heads passed to ``custom_model``.
        **kwargs: Ignored.

    Returns:
        A ``custom_model`` instance.
    """
    # Alternatives kept for reference:
    # model = ResNet(BasicBlock, keep_prob=keep_prob,
    #                avg_pool=avg_pool, **kwargs)
    # model = ProtoNet()
    model = custom_model(num_layer=num_layer)
    return model