BenjiELCA commited on
Commit
5950696
·
1 Parent(s): cbbc2ce

no COCO model loaded

Browse files
Files changed (1) hide show
  1. modules/train.py +2 -5
modules/train.py CHANGED
@@ -37,10 +37,7 @@ def get_arrow_model(num_classes, num_keypoints=2):
37
  """
38
  # Load a model pre-trained on COCO, initialized without pre-trained weights
39
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
40
- if device == torch.device('cuda'):
41
- model = keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_V1)
42
- else:
43
- model = keypointrcnn_resnet50_fpn(weights=None)
44
 
45
  # Get the number of input features for the classifier in the box predictor.
46
  in_features = model.roi_heads.box_predictor.cls_score.in_features
@@ -66,7 +63,7 @@ def get_faster_rcnn_model(num_classes):
66
  - model (torch.nn.Module): The modified Faster R-CNN model.
67
  """
68
  # Load a pre-trained Faster R-CNN model
69
- model = fasterrcnn_resnet50_fpn(weights=FasterRCNN_ResNet50_FPN_Weights.COCO_V1)
70
 
71
  # Get the number of input features for the classifier in the box predictor
72
  in_features = model.roi_heads.box_predictor.cls_score.in_features
 
37
  """
38
  # Load a model pre-trained on COCO, initialized without pre-trained weights
39
  device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
40
+ model = keypointrcnn_resnet50_fpn(weights=None)
 
 
 
41
 
42
  # Get the number of input features for the classifier in the box predictor.
43
  in_features = model.roi_heads.box_predictor.cls_score.in_features
 
63
  - model (torch.nn.Module): The modified Faster R-CNN model.
64
  """
65
  # Load a pre-trained Faster R-CNN model
66
+ model = fasterrcnn_resnet50_fpn(weights=None)
67
 
68
  # Get the number of input features for the classifier in the box predictor
69
  in_features = model.roi_heads.box_predictor.cls_score.in_features