# ocean_faster_RCNN / model.py
# Author: oliverlevn — commit 76d828d ("first-commit")
import torchvision
def create_fasterrcnn_model(num_classes):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a new box head.

    Every parameter loaded from the pretrained checkpoint is frozen. The box
    predictor is then replaced by a freshly created ``FastRCNNPredictor``.
    Its layers are constructed after the freeze, so they keep the default
    ``requires_grad=True``. As a result, only the new classification and
    box-regression head is trainable.

    Args:
        num_classes: Number of output classes for the new box predictor,
            *including* the background class (torchvision convention).

    Returns:
        A ``torchvision.models.detection.FasterRCNN`` model.

    Raises:
        ValueError: If ``num_classes`` is less than 1. The check runs before
            any weights are loaded, so a bad argument fails fast instead of
            after a checkpoint download.
    """
    if num_classes < 1:
        raise ValueError(
            f"num_classes must be >= 1 (including background), got {num_classes}"
        )

    detection = torchvision.models.detection

    # torchvision >= 0.13 deprecates ``pretrained=True`` in favour of the
    # ``weights=`` enum; ``DEFAULT`` resolves to the same COCO checkpoint.
    # Older releases lack the enum, so fall back to the legacy flag there.
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is not None:
        model = detection.fasterrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    else:
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)

    # Freeze all pretrained parameters (backbone, FPN, RPN and ROI heads).
    for param in model.parameters():
        param.requires_grad = False

    # Replace the ROI box predictor so it outputs `num_classes` classes. The
    # input width is taken from the existing head so the new head matches the
    # features produced by the frozen box head.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )
    return model