---
license: apache-2.0
---
MobileNet V2 model from Torchvision, fine-tuned on the FOOD101 dataset. The checkpoint was trained for 30 epochs using https://github.com/AlexKoff88/mobilenetv2_food101.
It reaches 76.3% top-1 accuracy, which leaves room for improvement.
The model is mainly intended for samples and demos of model optimization. Its advantages are:
- FOOD101 can be downloaded automatically, without registration or SMS verification.
- It is representative of real-world scenarios.
- MobileNet v2 is a lightweight model that is easy to train and is widely used in public benchmarks.

Here is the code to load the checkpoint in PyTorch:
```python
import sys
import os

import torch
import torch.nn as nn
import torchvision.models as models

FOOD101_CLASSES = 101


def fix_names(state_dict):
    # Checkpoints saved from an nn.DataParallel-wrapped model prefix every key
    # with 'module.'; strip it so the keys match a plain model.
    return {key.replace('module.', ''): value for key, value in state_dict.items()}


# Torchvision MobileNet V2 with the final linear layer resized to 101 classes.
model = models.mobilenet_v2()
num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, FOOD101_CLASSES)

if len(sys.argv) > 1:
    checkpoint_path = sys.argv[1]
    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        # map_location lets a GPU-saved checkpoint load on a CPU-only machine.
        checkpoint = torch.load(checkpoint_path, map_location='cpu')
        weights = fix_names(checkpoint['state_dict'])
        model.load_state_dict(weights)
        print("=> loaded checkpoint '{}' (epoch {})"
              .format(checkpoint_path, checkpoint['epoch']))
    else:
        print("=> no checkpoint found at '{}'".format(checkpoint_path))
```