Spaces:
Runtime error
Runtime error
Commit
·
ee240b7
1
Parent(s):
51b7d90
fixed state_dict
Browse files- __pycache__/model.cpython-310.pyc +0 -0
- app.py +11 -1
- model.py +2 -3
- model_food101_20_percent.pth +1 -1
__pycache__/model.cpython-310.pyc
ADDED
Binary file (1.11 kB). View file
|
|
app.py
CHANGED
@@ -18,7 +18,15 @@ vit16, vit16_transforms = create_vit16_model(
|
|
18 |
num_classes=101, # could also use len(class_names)
|
19 |
)
|
20 |
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
vit16.load_state_dict(
|
23 |
torch.load(
|
24 |
f="model_food101_20_percent.pth",
|
@@ -26,6 +34,8 @@ vit16.load_state_dict(
|
|
26 |
)
|
27 |
)
|
28 |
|
|
|
|
|
29 |
### 3. Predict function ###
|
30 |
|
31 |
# Create predict function
|
|
|
18 |
num_classes=101, # could also use len(class_names)
|
19 |
)
|
20 |
|
21 |
+
|
22 |
+
state_dict = torch.load("model_food101_20_percent.pth")
|
23 |
+
state_dict["heads.0.weight"] = state_dict.pop("heads.weight")
|
24 |
+
state_dict["heads.0.bias"] = state_dict.pop("heads.bias")
|
25 |
+
# save new state_dict in .pth
|
26 |
+
torch.save(state_dict, "model_food101_20_percent.pth")
|
27 |
+
|
28 |
+
|
29 |
+
|
30 |
vit16.load_state_dict(
|
31 |
torch.load(
|
32 |
f="model_food101_20_percent.pth",
|
|
|
34 |
)
|
35 |
)
|
36 |
|
37 |
+
|
38 |
+
|
39 |
### 3. Predict function ###
|
40 |
|
41 |
# Create predict function
|
model.py
CHANGED
@@ -28,8 +28,7 @@ def create_vit16_model(num_classes:int=101,
|
|
28 |
|
29 |
# Change classifier head with random seed for reproducibility
|
30 |
torch.manual_seed(seed)
|
31 |
-
model.
|
32 |
-
|
33 |
-
)
|
34 |
|
35 |
return model, transforms
|
|
|
28 |
|
29 |
# Change classifier head with random seed for reproducibility
|
30 |
torch.manual_seed(seed)
|
31 |
+
model.heads = nn.Sequential(nn.Linear(in_features=768, # keep this the same as original model
|
32 |
+
out_features=num_classes)) # update to reflect target number of classes
|
|
|
33 |
|
34 |
return model, transforms
|
model_food101_20_percent.pth
CHANGED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
size 343564561
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b4357a334ed5737baacaf7a99b0ba491ef88c61580790277acec9ef877cd77c9
|
3 |
size 343564561
|