File size: 1,487 Bytes
71eda4d
 
 
8cea71d
821cd0e
 
be558de
8cea71d
 
 
 
09ca532
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
---
license: apache-2.0
---

MobileNet V2 model from Torchvision fine-tuned for FOOD101 dataset. Checkpoint trained for 30 epoches using https://github.com/AlexKoff88/mobilenetv2_food101.

Top-1 accuracy is 76.3% but one can do better.

The main intend is to use it in samples and demos for model optimization. Here is the advantages:
- FOOD101 can automatically downloaded without registration and SMS.
- It is quite representative to reflect the real world scenarios.
- MobileNet v2 is easy to train and lightweight model which is also representative and used in many public benchmarks.

Here is the code to load the checkpoint in PyTorch:

```python
import sys
import os

import torch
import torch.nn as nn
import torchvision.models as models

FOOD101_CLASSES = 101

def fix_names(state_dict):
    state_dict = {key.replace('module.', ''): value for (key, value) in state_dict.items()}
    return state_dict

model = models.mobilenet_v2()    

num_ftrs = model.classifier[1].in_features
model.classifier[1] = nn.Linear(num_ftrs, FOOD101_CLASSES)

if len(sys.argv) > 1:
    checkpoint_path = sys.argv[1]

    if os.path.isfile(checkpoint_path):
        print("=> loading checkpoint '{}'".format(checkpoint_path))
        

        checkpoint = torch.load(checkpoint_path)
        weights = fix_names(checkpoint['state_dict'])
        model.load_state_dict(weights)

        print("=> loaded checkpoint '{}' (epoch {})"
                .format(checkpoint_path, checkpoint['epoch']))
```