# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

import torch
from mmdet.models.losses.utils import weight_reduce_loss
from torch import Tensor
from torch import nn as nn

from mmdet3d.registry import MODELS
from ..layers import PAConv, PAConvCUDA


def weight_correlation(conv: nn.Module) -> Tensor:
    """Calculate correlations between kernel weights in a Conv's weight bank
    as a regularization loss. Cosine similarity is used as the metric.

    Args:
        conv (nn.Module): A Conv module to be regularized.
            Currently we only support `PAConv` and `PAConvCUDA`.

    Returns:
        Tensor: Correlations between kernel weights in the weight bank.
    """
    assert isinstance(conv, (PAConv, PAConvCUDA)), \
        f'unsupported module type {type(conv)}'
    kernels = conv.weight_bank  # [C_in, num_kernels * C_out]
    in_channels = conv.in_channels
    out_channels = conv.out_channels
    num_kernels = conv.num_kernels

    # [num_kernels, C_in * C_out]
    flatten_kernels = kernels.view(in_channels, num_kernels, out_channels).\
        permute(1, 0, 2).reshape(num_kernels, -1)
    # pairwise inner products, [num_kernels, num_kernels]
    inner_product = torch.matmul(flatten_kernels, flatten_kernels.T)
    # L2 norm of each flattened kernel, [num_kernels, 1]
    kernel_norms = torch.sum(flatten_kernels**2, dim=-1, keepdim=True)**0.5
    # outer product of norms, [num_kernels, num_kernels]
    kernel_norms = torch.matmul(kernel_norms, kernel_norms.T)
    cosine_sims = inner_product / kernel_norms
    # take the upper triangular part excluding the diagonal since we only
    # compute the correlation between each pair of different kernels once;
    # the square ensures a positive loss, refer to:
    # https://github.com/CVMI-Lab/PAConv/blob/main/scene_seg/tool/train.py#L208
    corr = torch.sum(torch.triu(cosine_sims, diagonal=1)**2)

    return corr


def paconv_regularization_loss(modules: List[nn.Module],
                               reduction: str) -> Tensor:
    """Compute the correlation loss of PAConv weight kernels as
    regularization.

    Args:
        modules (List[nn.Module] | :obj:`generator`):
            A list or a Python generator of torch.nn.Modules.
        reduction (str): Method to reduce losses among PAConv modules.
            The valid reduction methods are 'none', 'sum' or 'mean'.

    Returns:
        Tensor: Correlation loss of kernel weights.
    """
    corr_loss = []
    # only PAConv layers contribute to the regularization term,
    # other module types are skipped
    for module in modules:
        if isinstance(module, (PAConv, PAConvCUDA)):
            corr_loss.append(weight_correlation(module))
    corr_loss = torch.stack(corr_loss)

    # perform reduction over all PAConv modules
    corr_loss = weight_reduce_loss(corr_loss, reduction=reduction)

    return corr_loss


@MODELS.register_module()
class PAConvRegularizationLoss(nn.Module):
    """Calculate correlation loss of kernel weights in PAConv's weight bank.

    This is used as a regularization term in PAConv model training.

    Args:
        reduction (str): Method to reduce losses. The reduction is performed
            among all PAConv modules instead of prediction tensors.
            The valid reduction methods are 'none', 'sum' or 'mean'.
            Defaults to 'mean'.
        loss_weight (float): Weight of loss. Defaults to 1.0.
    """

    def __init__(self,
                 reduction: str = 'mean',
                 loss_weight: float = 1.0) -> None:
        super(PAConvRegularizationLoss, self).__init__()
        assert reduction in ['none', 'sum', 'mean']
        self.reduction = reduction
        self.loss_weight = loss_weight

    def forward(self,
                modules: List[nn.Module],
                reduction_override: Optional[str] = None,
                **kwargs) -> Tensor:
        """Forward function of loss calculation.

        Args:
            modules (List[nn.Module] | :obj:`generator`):
                A list or a Python generator of torch.nn.Modules.
            reduction_override (str, optional): Method to reduce losses.
                The valid reduction methods are 'none', 'sum' or 'mean'.
                Defaults to None.

        Returns:
            Tensor: Correlation loss of kernel weights.
        """
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)

        return self.loss_weight * paconv_regularization_loss(
            modules, reduction=reduction)