from typing import Optional, Sequence

import torch
import torch.nn as nn
import torch.nn.utils.prune as prune
import torchvision
import torchvision.models as models

# from torchsummary import summary


class MobileNetV2FeatureExtractor(nn.Module):
    """Feature extractor built from the backbone of torchvision's Faster R-CNN v2.

    NOTE: despite the class name, this wraps the ResNet-50 + FPN backbone of
    ``fasterrcnn_resnet50_fpn_v2``, not MobileNetV2. The name (and the
    ``self.model`` attribute) are kept so existing callers and checkpoint
    state-dict keys keep working.

    The backbone is randomly initialised (no pretrained weights) and every
    parameter is trainable.
    """

    def __init__(self):
        super().__init__()
        # The full detector is built only to borrow its backbone. Passing
        # ``weights=None`` / ``weights_backbone=None`` is the non-deprecated
        # spelling of the old ``pretrained=False``.
        detector = torchvision.models.detection.fasterrcnn_resnet50_fpn_v2(
            weights=None, weights_backbone=None
        )
        self.model = detector.backbone
        for param in self.model.parameters():
            param.requires_grad = True

    def forward(self, x):
        """Run the backbone; returns the mapping of FPN feature maps it produces."""
        return self.model(x)


class GlobalAvgPool2D(nn.Module):
    """Global average pooling over a single FPN level.

    Takes the mapping returned by the FPN backbone and uses only the entry
    with key ``'0'`` (the first, highest-resolution level of torchvision's
    FPN); all other levels are ignored.
    """

    def forward(self, x):
        feature_map = x["0"]  # assumed (N, C, H, W)
        # flatten(2) reshapes rather than view()s, so it also works on
        # non-contiguous tensors. The result is (N, C).
        return feature_map.flatten(2).mean(dim=2)


class LDRNet_fasterrcnn(nn.Module):
    """Three-headed regression/classification model on a ResNet-50 FPN backbone.

    A single globally pooled feature vector feeds three linear heads:

    * ``corner``: 8 outputs (presumably 4 corner points as (x, y) pairs).
    * ``border``: ``(points_size - 4) * 2`` outputs, i.e. 2 values per
      remaining point.
    * ``cls``: ``sum(classification_list) + len(classification_list)``
      outputs (presumably one extra logit per group, e.g. "none" -- confirm
      against the loss).

    Args:
        points_size: total number of points; must be >= 4.
        classification_list: per-group class counts. Defaults to ``[1]``.
            The list is copied, so later mutation by the caller has no effect.

    Raises:
        ValueError: if ``points_size < 4``.
    """

    def __init__(
        self,
        points_size: int = 100,
        classification_list: Optional[Sequence[int]] = None,
    ):
        super().__init__()
        if points_size < 4:
            raise ValueError(f"points_size must be >= 4, got {points_size}")
        # A None sentinel avoids the shared mutable-default-list pitfall.
        if classification_list is None:
            classification_list = [1]
        self.points_size = points_size
        self.classification_list = list(classification_list)

        self.backbone = MobileNetV2FeatureExtractor()
        # FPN channel width (256 for this torchvision backbone).
        in_features = self.backbone.model.out_channels
        class_size = sum(self.classification_list)  # 0 for an empty list

        self.global_pool = GlobalAvgPool2D()
        self.corner = nn.Linear(in_features, 8)
        self.border = nn.Linear(in_features, (points_size - 4) * 2)
        self.cls = nn.Linear(in_features, class_size + len(self.classification_list))

    def forward(self, x):
        """Return ``(corner_output, border_output, cls_output)``, each of shape (N, k)."""
        pooled = self.global_pool(self.backbone(x))
        return self.corner(pooled), self.border(pooled), self.cls(pooled)


if __name__ == "__main__":
    xx = torch.zeros((1, 3, 224, 224))
    model = LDRNet_fasterrcnn()
    print(model)

    with torch.no_grad():
        y = model(xx)
    print("[INFO]: output shapes:", [tuple(t.shape) for t in y])

    pruned_modules = []
    for name, module in model.named_modules():
        if isinstance(module, nn.Conv2d):
            prune.l1_unstructured(module, name="weight", amount=0.2)
        elif isinstance(module, nn.Linear):
            prune.l1_unstructured(module, name="weight", amount=0.4)
        else:
            continue
        pruned_modules.append(module)

    # Pruning re-parametrises ``weight`` as ``weight_orig`` plus a mask
    # buffer, so the parameter count below does not shrink. Report the
    # zeroed weights separately.
    zeroed = sum(int((m.weight == 0).sum()) for m in pruned_modules)

    total_params = sum(p.numel() for p in model.parameters())
    total_trainable_params = sum(
        p.numel() for p in model.parameters() if p.requires_grad
    )
    print(f"[INFO]: {total_params:,} total parameters.")
    print(f"[INFO]: {total_trainable_params:,} trainable parameters.")
    print(f"[INFO]: {zeroed:,} weights zeroed by pruning.")