File size: 3,423 Bytes
e8e478e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 |
import torch
import torch.nn as nn
from typing import Union, List, Dict, Any, cast
import torchvision
import torch.nn.functional as F
class VGG(torch.nn.Module):
    """VGG wrapper that exposes intermediate activations.

    The ``features`` stack of a torchvision VGG model is cut into five
    consecutive stages (``layer1`` .. ``layer5``).  ``forward`` returns the
    activation after every stage together with the output of the second
    fully connected layer and the final logits.

    Args:
        arch_type: Architecture name, either ``'vgg11'`` or ``'vgg19'``.
        pretrained: Passed through to the torchvision constructor.
        progress: Passed through to the torchvision constructor.

    Raises:
        NotImplementedError: If ``arch_type`` is not a supported name.
    """

    def __init__(self, arch_type: str, pretrained: bool, progress: bool) -> None:
        super().__init__()

        # Half-open [start, stop) index ranges into ``official_vgg.features``
        # for the five stages of each supported architecture.
        stage_bounds = {
            'vgg11': ((0, 2), (2, 5), (5, 10), (10, 15), (15, 20)),
            'vgg19': ((0, 4), (4, 9), (9, 18), (18, 27), (27, 36)),
        }
        try:
            blocks = stage_bounds[arch_type]
        except (KeyError, TypeError):
            # TypeError covers unhashable values; callers still see the same
            # exception type as before, now with an actionable message.
            raise NotImplementedError(
                f"Unsupported arch_type {arch_type!r}; "
                f"expected one of {sorted(stage_bounds)}"
            ) from None

        # The module right after the last stage is kept separately as
        # ``max_pool``; its index equals the stop bound of the fifth stage.
        last_idx = blocks[-1][1]

        # NOTE(review): newer torchvision releases appear to deprecate the
        # ``pretrained`` keyword in favour of ``weights`` -- confirm against
        # the torchvision version this project pins before changing it.
        build_official = getattr(torchvision.models, arch_type)
        official_vgg = build_official(pretrained=pretrained, progress=progress)

        # Register layer1..layer5 in order.  Sub-modules keep their original
        # index inside ``features`` as their name, so parameter names in the
        # state dict look like ``layer3.6.weight``.
        for stage_no, (start, stop) in enumerate(blocks, start=1):
            stage = torch.nn.Sequential()
            for idx in range(start, stop):
                stage.add_module(str(idx), official_vgg.features[idx])
            setattr(self, f'layer{stage_no}', stage)

        self.max_pool = official_vgg.features[last_idx]
        self.avgpool = nn.AdaptiveAvgPool2d((7, 7))

        # Entries 0, 3 and 6 of the torchvision classifier are reused as the
        # three fully connected layers; the activations and dropout between
        # them are re-applied explicitly in ``forward``.
        self.fc1 = official_vgg.classifier[0]
        self.fc2 = official_vgg.classifier[3]
        self.fc3 = official_vgg.classifier[6]
        self.dropout = nn.Dropout()

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run the network and collect intermediate outputs.

        Args:
            x: Input batch.  # assumes an (N, 3, H, W) image tensor -- TODO confirm

        Returns:
            A dict with keys ``'f0'`` .. ``'f4'`` (output of ``layer1`` ..
            ``layer5``), ``'penultimate'`` (ReLU output after ``fc2``, taken
            before dropout) and ``'logits'`` (output of ``fc3``).
        """
        out = {}
        x = self.layer1(x)
        out['f0'] = x
        x = self.layer2(x)
        out['f1'] = x
        x = self.layer3(x)
        out['f2'] = x
        x = self.layer4(x)
        out['f3'] = x
        x = self.layer5(x)
        out['f4'] = x

        x = self.max_pool(x)
        x = self.avgpool(x)
        # Flatten everything but the batch dimension.  Unlike
        # ``x.view(-1, 512*7*7)`` this also works for non-contiguous
        # (e.g. channels_last) activations and for an empty batch.
        x = torch.flatten(x, 1)

        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = F.relu(x)
        out['penultimate'] = x
        x = self.dropout(x)
        x = self.fc3(x)
        out['logits'] = x
        return out
def vgg11(pretrained=False, progress=True):
    r"""Build the 11-layer VGG wrapper (configuration "A").

    Architecture described in `"Very Deep Convolutional Networks For
    Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        A ``VGG`` instance constructed with ``arch_type='vgg11'``.
    """
    model = VGG(arch_type='vgg11', pretrained=pretrained, progress=progress)
    return model
def vgg19(pretrained=False, progress=True):
    r"""Build the 19-layer VGG wrapper (configuration "E").

    Architecture described in `"Very Deep Convolutional Networks For
    Large-Scale Image Recognition" <https://arxiv.org/pdf/1409.1556.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr

    Returns:
        A ``VGG`` instance constructed with ``arch_type='vgg19'``.
    """
    model = VGG(arch_type='vgg19', pretrained=pretrained, progress=progress)
    return model
|