Theem committed on
Commit d55be37 · 1 Parent(s): 4ec51d3

Update readme instructions

Files changed (1)
  1. README.md +49 -0
README.md CHANGED
@@ -8,3 +8,52 @@ The COCO images were transformed to grayscale using PIL. The hyperparameters and

Can be used as a pretrained model for multispectral imaging as suggested in this [paper](https://ceur-ws.org/Vol-2771/AICS2020_paper_50.pdf).

The file is provided as a state_dict, so to initialize the model, run:

```python
import torch
import torchvision

# Load the pretrained weights; model_path is the path to the downloaded
# checkpoint file, which stores the weights under the 'model' key
state_dict = torch.load(model_path, map_location=torch.device('cpu'))['model']

# Load the torchvision Faster R-CNN model
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights='DEFAULT')

# Adapt the input convolution to a single (grayscale) channel
model.backbone.body.conv1 = torch.nn.Conv2d(1, 64,
                                            kernel_size=(7, 7), stride=(2, 2),
                                            padding=(3, 3), bias=False).requires_grad_(True)
model.load_state_dict(state_dict)
```
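
With torchvision's default settings the model's built-in `GeneralizedRCNNTransform` still normalizes inputs with three-channel ImageNet statistics, so it has to be pointed at single-channel statistics before grayscale inference. A minimal sanity-check sketch, assuming 0.5 / 0.25 as the grayscale mean and standard deviation (these values are assumptions; check them against the training hyperparameters mentioned above):

```python
# Assumed single-channel normalization statistics (not confirmed by the checkpoint)
model.transform.image_mean = [0.5]
model.transform.image_std = [0.25]

model.eval()
with torch.no_grad():
    # One dummy grayscale image of shape (channels, height, width)
    predictions = model([torch.rand(1, 512, 512)])
print(predictions[0]['boxes'].shape)
```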

If it's going to be used for multispectral data, edit the first layer and duplicate the weights across the input channels:

```python
import math
import torch

# Assumes `model` was created as in the snippet above and `in_chans` is the
# number of spectral channels
state_dict = torch.load(model_path, map_location=torch.device('cpu'))['model']

# Duplicate the pretrained weights across the desired number of input channels
conv1_weight = state_dict['backbone.body.conv1.weight']
conv1_type = conv1_weight.dtype
conv1_weight = conv1_weight.float()
repeat = int(math.ceil(in_chans / conv1_weight.shape[1]))  # checkpoint conv1 has 1 input channel
conv1_weight = conv1_weight.repeat(1, repeat, 1, 1)[:, :in_chans, :, :]
# conv1_weight *= (3 / float(in_chans))
conv1_weight = conv1_weight.to(conv1_type)
state_dict['backbone.body.conv1.weight'] = conv1_weight

# Replace the input convolution so it accepts in_chans channels
model.backbone.body.conv1 = torch.nn.Conv2d(in_chans, 64,
                                            kernel_size=(7, 7), stride=(2, 2),
                                            padding=(3, 3), bias=False).requires_grad_(True)
model.load_state_dict(state_dict)
```
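
Duplicating the pretrained single-channel filters keeps the learned low-level features while letting every band contribute from the start. A quick shape check, as a sketch assuming `in_chans = 4`:

```python
# After the adaptation above, conv1 should accept in_chans input channels
print(model.backbone.body.conv1.weight.shape)  # e.g. torch.Size([64, 4, 7, 7]) for in_chans = 4
```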

For Faster R-CNN, the input normalization transform may also need to be adapted. Here is an example:

```python
import torchvision

# Generic statistics for every channel; when more than three channels are used,
# the first three keep the ImageNet RGB statistics
coco_mean = [0.5] * in_chans
coco_std = [0.25] * in_chans
if in_chans > 3:
    coco_mean[:3] = [0.485, 0.456, 0.406]
    coco_std[:3] = [0.229, 0.224, 0.225]

transform = torchvision.models.detection.transform.GeneralizedRCNNTransform(min_size=800,
                                                                             max_size=1333,
                                                                             image_mean=coco_mean,
                                                                             image_std=coco_std)
```
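
The new transform can then replace the model's built-in one. A minimal end-to-end sketch, assuming the multispectral adaptation from the previous snippets has been applied:

```python
# Swap in the multispectral normalization transform
model.transform = transform

model.eval()
with torch.no_grad():
    # One dummy multispectral image of shape (in_chans, height, width)
    predictions = model([torch.rand(in_chans, 512, 512)])
print(predictions[0].keys())  # dict_keys(['boxes', 'labels', 'scores'])
```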