Spaces:
Build error
Build error
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
def create_fasterrcnn_model(num_classes: int, *, freeze_pretrained: bool = True):
    """Build a COCO-pretrained Faster R-CNN (ResNet-50 FPN) with a fresh box predictor.

    The pretrained detection model is loaded. Its final ROI box predictor is
    then replaced with a new ``FastRCNNPredictor`` that produces
    ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes for the new head. torchvision's
            ``FastRCNNPredictor`` counts the background class, so pass the
            number of object categories plus one.
        freeze_pretrained: If True (the default, matching the original
            behavior), every pretrained parameter is frozen before the head
            is replaced. Only the newly created box predictor is then
            trainable. Pass False to fine-tune the whole network.

    Returns:
        The Faster R-CNN model with the replaced box predictor.
    """
    detection = torchvision.models.detection

    # torchvision >= 0.13 deprecates the ``pretrained`` flag in favour of the
    # multi-weight API. DEFAULT resolves to the same COCO checkpoint that
    # ``pretrained=True`` selected.
    weights_enum = getattr(detection, "FasterRCNN_ResNet50_FPN_Weights", None)
    if weights_enum is not None:
        model = detection.fasterrcnn_resnet50_fpn(weights=weights_enum.DEFAULT)
    else:  # older torchvision only understands the legacy flag
        model = detection.fasterrcnn_resnet50_fpn(pretrained=True)

    if freeze_pretrained:
        # Freeze before the head swap. The new predictor's freshly created
        # layers keep requires_grad=True and therefore remain trainable.
        for param in model.parameters():
            param.requires_grad = False

    # Size the replacement predictor from the existing head's input width.
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes)
    return model