jcarnero committed on
Commit
57f6a10
1 Parent(s): b26fc68

deployment model implementation

Browse files
Files changed (1) hide show
  1. deployment/model.py +64 -0
deployment/model.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import collections
2
+ from torch import nn
3
+ from timm import create_model
4
+
5
+
6
def get_model() -> nn.Sequential:
    """Build the (untrained) deployment network: a timm backbone plus a classifier head.

    The backbone is timm's ``vit_tiny_patch16_224`` created without pretrained
    weights, with ``num_classes=0`` and three input channels.  The head takes
    192 input features and produces 37 outputs.
    NOTE(review): ``num_classes=0`` presumably makes timm return pooled
    features instead of logits -- confirm against the timm documentation.

    The layers are passed positionally, so ``state_dict`` keys are index based:
    ``0.*`` for the backbone and ``1.<i>.*`` for the head layers.

    Returns:
        ``nn.Sequential(backbone, head)``.
    """
    backbone = create_model(
        "vit_tiny_patch16_224",
        pretrained=False,
        num_classes=0,
        in_chans=3,
    )

    # Order matters: the position of each layer becomes its state_dict index.
    head_layers = [
        nn.BatchNorm1d(192),
        nn.Dropout(0.25),
        nn.Linear(192, 512, bias=False),
        nn.ReLU(inplace=True),
        nn.BatchNorm1d(512),
        nn.Dropout(0.5),
        nn.Linear(512, 37, bias=False),
    ]
    classifier = nn.Sequential(*head_layers)

    return nn.Sequential(backbone, classifier)
22
+
23
+
24
def copy_weight(name, parameter, state_dict):
    """
    Copy the tensor stored under `name` in `state_dict` into `parameter`, in place.

    Names that belong to the first sub-module (prefix ``"0."``) are looked up
    with an extra ``"model."`` segment inserted after the prefix, e.g.
    ``"0.norm.bias"`` is read from ``"0.model.norm.bias"``.
    NOTE(review): presumably the checkpoint wraps the body in an attribute
    called ``model`` -- confirm against the code that produced the weights.

    A missing key or a shape mismatch is reported with ``print`` and the
    parameter is left untouched; nothing is raised.

    Args:
        name: Layer name as it appears in the target model's ``state_dict``.
        parameter: Destination tensor; overwritten in place on success.
        state_dict: Mapping of layer name to source tensor.
    """
    # Local import: the module only pulls `nn` out of torch at top level.
    import torch

    # Part of the body
    if name.startswith("0."):
        name = "0.model." + name[2:]

    if name not in state_dict:
        print(f"{name} is not in the state_dict, skipping.")
        return

    input_parameter = state_dict[name]
    if input_parameter.shape != parameter.shape:
        print(f"Shape mismatch at layer: {name}, skipping")
        return

    # no_grad lets the in-place copy also work on leaf tensors that require
    # grad (a real nn.Parameter), not only on detached state_dict tensors.
    with torch.no_grad():
        parameter.copy_(input_parameter)
41
+
42
+
43
def apply_weights(
    input_model: nn.Module,
    input_weights: collections.OrderedDict,
    application_function: callable,
):
    """
    Push the tensors in `input_weights` into `input_model`, one entry at a time.

    For every entry of ``input_model.state_dict()`` the `application_function`
    is called as ``application_function(name, tensor, input_weights)``; it is
    expected to update ``tensor`` itself.  The resulting state dict is then
    loaded back into the model.

    Args:
        input_model (`nn.Module`):
            Model that receives the weights.
        input_weights (`collections.OrderedDict`):
            The trained model's `state_dict()`.
        application_function (`callable`):
            Called with a layer name, that layer's tensor from `input_model`,
            and `input_weights`; decides how (and whether) the tensor is filled.
    """
    target_state = input_model.state_dict()
    for layer_name in target_state:
        application_function(layer_name, target_state[layer_name], input_weights)
    # Reload explicitly so the model reflects whatever the callback wrote.
    input_model.load_state_dict(target_state)