AlexKoff88 commited on
Commit
09ca532
·
1 Parent(s): 8cea71d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +37 -1
README.md CHANGED
@@ -7,4 +7,40 @@ The model checkpoint trained using https://github.com/AlexKoff88/mobilenetv2_foo
7
  The main intend is to use it in samples and demos for model optimization. Here is the advantages:
8
  - FOOD101 can automatically downloaded without registration and SMS.
9
  - It is quite representative to reflect the real world scenarios.
10
- - MobileNet v2 is easy to train and lightweight model which is also representative and used in many public benchmarks.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  The main intend is to use it in samples and demos for model optimization. Here is the advantages:
8
  - FOOD101 can automatically downloaded without registration and SMS.
9
  - It is quite representative to reflect the real world scenarios.
10
+ - MobileNet v2 is easy to train and lightweight model which is also representative and used in many public benchmarks.
11
+
12
+ Here is the code to load the checkpoint in PyTorch:
13
+
14
+ ```python
15
+ import sys
16
+ import os
17
+
18
+ import torch
19
+ import torch.nn as nn
20
+ import torchvision.models as models
21
+
22
+ FOOD101_CLASSES = 101
23
+
24
+ def fix_names(state_dict):
25
+ state_dict = {key.replace('module.', ''): value for (key, value) in state_dict.items()}
26
+ return state_dict
27
+
28
+ model = models.mobilenet_v2()
29
+
30
+ num_ftrs = model.classifier[1].in_features
31
+ model.classifier[1] = nn.Linear(num_ftrs, FOOD101_CLASSES)
32
+
33
+ if len(sys.argv) > 1:
34
+ checkpoint_path = sys.argv[1]
35
+
36
+ if os.path.isfile(checkpoint_path):
37
+ print("=> loading checkpoint '{}'".format(checkpoint_path))
38
+
39
+
40
+ checkpoint = torch.load(checkpoint_path)
41
+ weights = fix_names(checkpoint['state_dict'])
42
+ model.load_state_dict(weights)
43
+
44
+ print("=> loaded checkpoint '{}' (epoch {})"
45
+ .format(checkpoint_path, checkpoint['epoch']))
46
+ ```