import torchvision


def create_fasterrcnn_model(num_classes, *, freeze_base=True):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a new box predictor.

    The pretrained ROI box predictor is replaced by a freshly initialised
    ``FastRCNNPredictor`` sized for ``num_classes``. When ``freeze_base`` is
    true (the default), every pretrained parameter is frozen first, so only
    the new predictor's parameters remain trainable.

    Args:
        num_classes: Number of output classes for the new predictor. Following
            the torchvision convention, this count includes the background
            class.
        freeze_base: If true, set ``requires_grad = False`` on all parameters
            of the pretrained model before the predictor is replaced.

    Returns:
        The ``torchvision`` Faster R-CNN model with its new box predictor.

    Raises:
        ValueError: If ``num_classes`` is less than 1.
    """
    if num_classes < 1:
        raise ValueError(f"num_classes must be >= 1, got {num_classes!r}")

    detection = torchvision.models.detection
    # torchvision >= 0.13 replaced the deprecated ``pretrained`` flag with the
    # ``weights`` enum API; older releases only understand the legacy flag.
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is None:
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)
    else:
        model = detection.fasterrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)

    if freeze_base:
        # Freeze before swapping the head, so the new predictor's parameters
        # keep their default requires_grad=True and remain trainable.
        for param in model.parameters():
            param.requires_grad = False

    # Replace the ROI box predictor, keeping the input width the existing
    # box head produces.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = detection.faster_rcnn.FastRCNNPredictor(
        in_features, num_classes
    )
    return model