ybbwcwaps
AI Video
3cc4a06
raw
history blame contribute delete
No virus
3.42 kB
import torch
import torch.nn as nn
from typing import Union, List, Dict, Any, cast
import torchvision
import torch.nn.functional as F
class VGG(torch.nn.Module):
    """Multi-output wrapper around a torchvision VGG network.

    The torchvision model's convolutional ``features`` stack is split into five
    sequential stages (``layer1`` .. ``layer5``) at its max-pool layers, and the
    three ``Linear`` layers of its classifier are reused as ``fc1`` .. ``fc3``.
    ``forward`` returns every stage's activation together with the penultimate
    embedding and the logits.

    Args:
        arch_type (str): name of the torchvision builder to use; one of
            ``VGG._SUPPORTED_ARCHS`` (e.g. ``'vgg11'`` or ``'vgg19'``).
        pretrained (bool): forwarded unchanged to the torchvision builder.
        progress (bool): forwarded unchanged to the torchvision builder.

    Raises:
        NotImplementedError: if ``arch_type`` is not a supported VGG variant.
    """

    # All torchvision VGG builders produce a ``features`` stack with exactly five
    # MaxPool2d layers and the same classifier layout, so each can be split the
    # same way by ``_stage_bounds``.
    _SUPPORTED_ARCHS = (
        'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn',
        'vgg16', 'vgg16_bn', 'vgg19', 'vgg19_bn',
    )

    def __init__(self, arch_type, pretrained, progress):
        super().__init__()
        # Register the stages up front (in order) so the module/state_dict
        # layout is layer1..layer5, max_pool, avgpool, fc1..fc3, dropout.
        self.layer1 = torch.nn.Sequential()
        self.layer2 = torch.nn.Sequential()
        self.layer3 = torch.nn.Sequential()
        self.layer4 = torch.nn.Sequential()
        self.layer5 = torch.nn.Sequential()
        if arch_type not in self._SUPPORTED_ARCHS:
            raise NotImplementedError(
                f"unsupported VGG architecture {arch_type!r}; "
                f"expected one of {self._SUPPORTED_ARCHS}"
            )
        # NOTE(review): newer torchvision releases deprecate ``pretrained=`` in
        # favour of ``weights=``; kept for compatibility with older versions —
        # confirm against the pinned torchvision version.
        builder = getattr(torchvision.models, arch_type)
        official_vgg = builder(pretrained=pretrained, progress=progress)
        features = official_vgg.features

        bounds, last_idx = self._stage_bounds(features)
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        for stage, (start, end) in zip(stages, bounds):
            for idx in range(start, end):
                # Keep the original feature index as the child name so the
                # state_dict keys (e.g. "layer2.3.weight") match the torchvision
                # numbering.
                stage.add_module(str(idx), features[idx])
        # The final max-pool is kept outside layer5 so that 'f4' in forward()
        # is the un-pooled output of the last stage.
        self.max_pool = features[last_idx]
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))
        # Indices 0, 3 and 6 of the torchvision classifier are its Linear layers;
        # the ReLU/Dropout modules between them are re-applied in forward().
        self.fc1 = official_vgg.classifier[0]
        self.fc2 = official_vgg.classifier[3]
        self.fc3 = official_vgg.classifier[6]
        # Shared dropout (nn.Dropout default p=0.5), applied after fc1 and fc2.
        self.dropout = nn.Dropout()

    @staticmethod
    def _stage_bounds(features):
        """Split a VGG ``features`` stack into five stages at its max-pool layers.

        Stage 1 spans index 0 up to (not including) the first pool; every later
        stage starts *at* a pool layer and ends just before the next one, so
        stages 2-5 each open by downsampling.

        Args:
            features: an indexable/iterable stack of modules
                (an ``nn.Sequential`` in practice).

        Returns:
            tuple: ``(bounds, last_pool)`` where ``bounds`` is a list of five
            half-open ``(start, end)`` index ranges and ``last_pool`` is the index
            of the final max-pool layer.

        Raises:
            ValueError: if ``features`` does not contain exactly five
                ``nn.MaxPool2d`` layers.
        """
        pools = [i for i, module in enumerate(features) if isinstance(module, nn.MaxPool2d)]
        if len(pools) != 5:
            raise ValueError(
                f"expected 5 MaxPool2d layers in VGG features, found {len(pools)}"
            )
        starts = [0] + pools[:4]
        return list(zip(starts, pools)), pools[4]

    def forward(self, x):
        """Run the network and collect intermediate activations.

        Args:
            x: input batch; presumably an (N, 3, H, W) image tensor — confirm
                with callers.

        Returns:
            dict: ``'f0'`` .. ``'f4'`` are the outputs of ``layer1`` ..
            ``layer5`` (``'f4'`` is taken before the final max-pool);
            ``'penultimate'`` is the ReLU'd ``fc2`` output before dropout;
            ``'logits'`` is the ``fc3`` output.
        """
        out = {}
        stages = (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5)
        for i, stage in enumerate(stages):
            x = stage(x)
            out[f'f{i}'] = x
        x = self.max_pool(x)
        x = self.avgpool(x)
        # Keep the batch dimension explicit: view(-1, 512*7*7) would silently
        # fold a channel/shape mismatch into the batch axis and fails on
        # non-contiguous (e.g. channels_last) tensors.
        x = torch.flatten(x, 1)
        x = self.dropout(F.relu(self.fc1(x)))
        x = F.relu(self.fc2(x))
        out['penultimate'] = x
        x = self.fc3(self.dropout(x))
        out['logits'] = x
        return out
def vgg11(pretrained=False, progress=True):
    r"""Build the 11-layer VGG (configuration "A") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        VGG: a module whose forward pass returns a dict holding the five stage
        activations ``'f0'`` .. ``'f4'`` plus ``'penultimate'`` and ``'logits'``.
    """
    model = VGG(arch_type='vgg11', pretrained=pretrained, progress=progress)
    return model
def vgg19(pretrained=False, progress=True):
    r"""VGG 19-layer model (configuration "E") from
    `"Very Deep Convolutional Networks For Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        VGG: a module whose forward pass returns a dict holding the five stage
        activations ``'f0'`` .. ``'f4'`` plus ``'penultimate'`` and ``'logits'``.
    """
    return VGG('vgg19', pretrained, progress)