from typing import Any, List

from torch import nn

from detectron2.config import CfgNode
from detectron2.structures import Instances

from .cycle_pix2shape import PixToShapeCycleLoss
from .cycle_shape2shape import ShapeToShapeCycleLoss
from .embed import EmbeddingLoss
from .embed_utils import CseAnnotationsAccumulator
from .mask_or_segm import MaskOrSegmentationLoss
from .registry import DENSEPOSE_LOSS_REGISTRY
from .soft_embed import SoftEmbeddingLoss
from .utils import BilinearInterpolationHelper, LossDict, extract_packed_annotations_from_matches


@DENSEPOSE_LOSS_REGISTRY.register()
class DensePoseCseLoss:
    """
    DensePose CSE (Continuous Surface Embeddings) loss: combines a coarse
    segmentation loss with per-mesh embedding losses and, optionally,
    shape-to-shape and pixel-to-shape cycle losses.
    """

    _EMBED_LOSS_REGISTRY = {
        EmbeddingLoss.__name__: EmbeddingLoss,
        SoftEmbeddingLoss.__name__: SoftEmbeddingLoss,
    }

    def __init__(self, cfg: CfgNode):
        """
        Initialize CSE loss from configuration options

        Args:
            cfg (CfgNode): configuration options
        """
        self.w_segm = cfg.MODEL.ROI_DENSEPOSE_HEAD.INDEX_WEIGHTS
        self.w_embed = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_WEIGHT
        self.segm_loss = MaskOrSegmentationLoss(cfg)
        self.embed_loss = DensePoseCseLoss.create_embed_loss(cfg)
        self.do_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_shape2shape:
            self.w_shape2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.SHAPE_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.shape2shape_loss = ShapeToShapeCycleLoss(cfg)
        self.do_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.ENABLED
        if self.do_pix2shape:
            self.w_pix2shape = cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.PIX_TO_SHAPE_CYCLE_LOSS.WEIGHT
            self.pix2shape_loss = PixToShapeCycleLoss(cfg)

    @classmethod
    def create_embed_loss(cls, cfg: CfgNode):
        """
        Instantiate the embedding loss class named by
        MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME (one of the keys of
        `_EMBED_LOSS_REGISTRY`).
        """
        return cls._EMBED_LOSS_REGISTRY[cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME](cfg)
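
    # Illustrative example (assumed config value): with
    #     cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBED_LOSS_NAME = "SoftEmbeddingLoss"
    # `create_embed_loss(cfg)` returns SoftEmbeddingLoss(cfg); the name must be
    # one of the keys of `_EMBED_LOSS_REGISTRY`.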

    def __call__(
        self,
        proposals_with_gt: List[Instances],
        densepose_predictor_outputs: Any,
        embedder: nn.Module,
    ) -> LossDict:
        """
        Compute CSE losses for the given proposals. Returns a dict with
        "loss_densepose_S" (segmentation), one "loss_densepose_E{mesh_id}"
        entry per mesh and, if enabled, the "loss_shape2shape" and
        "loss_pix2shape" cycle losses.
        """
        if not len(proposals_with_gt):
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)
        accumulator = CseAnnotationsAccumulator()
        packed_annotations = extract_packed_annotations_from_matches(proposals_with_gt, accumulator)
        if packed_annotations is None:
            return self.produce_fake_losses(densepose_predictor_outputs, embedder)
        h, w = densepose_predictor_outputs.embedding.shape[2:]
        interpolator = BilinearInterpolationHelper.from_matches(
            packed_annotations,
            (h, w),
        )
        meshid_to_embed_losses = self.embed_loss(
            proposals_with_gt,
            densepose_predictor_outputs,
            packed_annotations,
            interpolator,
            embedder,
        )
        embed_loss_dict = {
            f"loss_densepose_E{meshid}": self.w_embed * meshid_to_embed_losses[meshid]
            for meshid in meshid_to_embed_losses
        }
        all_loss_dict = {
            "loss_densepose_S": self.w_segm
            * self.segm_loss(proposals_with_gt, densepose_predictor_outputs, packed_annotations),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.w_shape2shape * self.shape2shape_loss(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.w_pix2shape * self.pix2shape_loss(
                proposals_with_gt, densepose_predictor_outputs, packed_annotations, embedder
            )
        return all_loss_dict

    def produce_fake_losses(
        self, densepose_predictor_outputs: Any, embedder: nn.Module
    ) -> LossDict:
        """
        Produce placeholder ("fake") losses with the same keys as `__call__`;
        used when no ground truth annotations are available, so that all
        predictor outputs still take part in the loss computation.
        """
        meshname_to_embed_losses = self.embed_loss.fake_values(
            densepose_predictor_outputs, embedder=embedder
        )
        embed_loss_dict = {
            f"loss_densepose_E{mesh_name}": meshname_to_embed_losses[mesh_name]
            for mesh_name in meshname_to_embed_losses
        }
        all_loss_dict = {
            "loss_densepose_S": self.segm_loss.fake_value(densepose_predictor_outputs),
            **embed_loss_dict,
        }
        if self.do_shape2shape:
            all_loss_dict["loss_shape2shape"] = self.shape2shape_loss.fake_value(embedder)
        if self.do_pix2shape:
            all_loss_dict["loss_pix2shape"] = self.pix2shape_loss.fake_value(
                densepose_predictor_outputs, embedder
            )
        return all_loss_dict
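

# Usage sketch (illustrative only; `cfg`, `proposals_with_gt`,
# `densepose_predictor_outputs` and `embedder` are assumed to come from the
# surrounding DensePose training code):
#
#     loss_fn = DENSEPOSE_LOSS_REGISTRY.get("DensePoseCseLoss")(cfg)
#     losses = loss_fn(proposals_with_gt, densepose_predictor_outputs, embedder)
#     # losses maps names such as "loss_densepose_S" and "loss_densepose_E<mesh>"
#     # to scalar tensors that can be summed with the other detection losses.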