Spaces:
Running
on
Zero
Running
on
Zero
File size: 698 Bytes
3f52b3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 |
import torch.nn as nn
import torchvision.models as models
class VGG_16(nn.Module):
    """Truncated VGG-16 feature stack that returns intermediate activations.

    The network is the first 30 modules of ``torchvision``'s
    ``vgg16(weights='DEFAULT').features``. The modules at
    ``_POOL_INDICES`` are swapped for 2x2 average pooling, and ``forward``
    returns the tensors produced at ``_FEATURE_INDICES`` instead of a
    single output.
    """

    # Positions in the truncated stack that get replaced by average pooling.
    # NOTE(review): presumably these are the max-pooling layers of
    # torchvision's VGG-16 layout -- confirm against the installed version.
    _POOL_INDICES = (4, 9, 16, 23)

    # Positions whose outputs are collected and returned by ``forward``.
    _FEATURE_INDICES = frozenset({0, 5, 10, 17, 24})

    def __init__(self) -> None:
        """Build the truncated stack and install the average-pooling layers."""
        super().__init__()
        self.model = models.vgg16(weights='DEFAULT').features[:30]
        # Replace only the known positions instead of scanning every layer.
        for idx in self._POOL_INDICES:
            self.model[idx] = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

    def forward(self, x) -> list:
        """Run ``x`` through every layer and collect the tapped outputs.

        Args:
            x: Input passed to the first layer of ``self.model``.

        Returns:
            A list with one entry per index in ``_FEATURE_INDICES``, in
            layer order.
        """
        features = []
        # Layers after the last tapped index still run even though their
        # output is not returned; this is kept so results stay identical.
        # NOTE(review): if the ReLU modules in the stack are in-place, a
        # tensor captured here is presumably overwritten by the ReLU that
        # follows it -- confirm whether callers expect pre- or
        # post-activation features before shortening this loop.
        for idx, layer in enumerate(self.model):
            x = layer(x)
            if idx in self._FEATURE_INDICES:
                features.append(x)
        return features
# Script entry point: build the network and print its layer structure.
if __name__ == '__main__':
    model = VGG_16()
    print(model)