Spaces:
Running
Running
no COCO model loaded
Browse files- modules/train.py +2 -5
modules/train.py
CHANGED
@@ -37,10 +37,7 @@ def get_arrow_model(num_classes, num_keypoints=2):
|
|
37 |
"""
|
38 |
# Load a model pre-trained on COCO, initialized without pre-trained weights
|
39 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
40 |
-
|
41 |
-
model = keypointrcnn_resnet50_fpn(weights=KeypointRCNN_ResNet50_FPN_Weights.COCO_V1)
|
42 |
-
else:
|
43 |
-
model = keypointrcnn_resnet50_fpn(weights=None)
|
44 |
|
45 |
# Get the number of input features for the classifier in the box predictor.
|
46 |
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
@@ -66,7 +63,7 @@ def get_faster_rcnn_model(num_classes):
|
|
66 |
- model (torch.nn.Module): The modified Faster R-CNN model.
|
67 |
"""
|
68 |
# Load a pre-trained Faster R-CNN model
|
69 |
-
model = fasterrcnn_resnet50_fpn(weights=
|
70 |
|
71 |
# Get the number of input features for the classifier in the box predictor
|
72 |
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
|
|
37 |
"""
|
38 |
# Load a model pre-trained on COCO, initialized without pre-trained weights
|
39 |
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
40 |
+
model = keypointrcnn_resnet50_fpn(weights=None)
|
|
|
|
|
|
|
41 |
|
42 |
# Get the number of input features for the classifier in the box predictor.
|
43 |
in_features = model.roi_heads.box_predictor.cls_score.in_features
|
|
|
63 |
- model (torch.nn.Module): The modified Faster R-CNN model.
|
64 |
"""
|
65 |
# Load a pre-trained Faster R-CNN model
|
66 |
+
model = fasterrcnn_resnet50_fpn(weights=None)
|
67 |
|
68 |
# Get the number of input features for the classifier in the box predictor
|
69 |
in_features = model.roi_heads.box_predictor.cls_score.in_features
|