|
|
|
|
|
import logging |
|
import os |
|
from collections import OrderedDict |
|
from typing import List, Optional, Union |
|
import torch |
|
from torch import nn |
|
|
|
from detectron2.checkpoint import DetectionCheckpointer |
|
from detectron2.config import CfgNode |
|
from detectron2.engine import DefaultTrainer |
|
from detectron2.evaluation import ( |
|
DatasetEvaluator, |
|
DatasetEvaluators, |
|
inference_on_dataset, |
|
print_csv_format, |
|
) |
|
from detectron2.solver.build import get_default_optimizer_params, maybe_add_gradient_clipping |
|
from detectron2.utils import comm |
|
from detectron2.utils.events import EventWriter, get_event_storage |
|
|
|
from densepose import DensePoseDatasetMapperTTA, DensePoseGeneralizedRCNNWithTTA, load_from_cfg |
|
from densepose.data import ( |
|
DatasetMapper, |
|
build_combined_loader, |
|
build_detection_test_loader, |
|
build_detection_train_loader, |
|
build_inference_based_loaders, |
|
has_inference_based_loaders, |
|
) |
|
from densepose.evaluation.d2_evaluator_adapter import Detectron2COCOEvaluatorAdapter |
|
from densepose.evaluation.evaluator import DensePoseCOCOEvaluator, build_densepose_evaluator_storage |
|
from densepose.modeling.cse import Embedder |
|
|
|
|
|
class SampleCountingLoader:
    """
    Wrap a batch loader and log, for every batch yielded, how many instances
    came from each dataset.

    Every element of a batch is read as a mapping with a ``"dataset"`` entry
    (the dataset name) and an ``"instances"`` entry that supports ``len()``.
    For each dataset present in a batch, the summed instance count is written
    to the current event storage as the scalar ``batch/<dataset_name>``.
    Batches are passed through unchanged.
    """

    def __init__(self, loader):
        # Any iterable of batches; it is re-iterated on every ``iter()`` call.
        self.loader = loader

    def __iter__(self):
        # The storage is looked up once, when iteration starts. NOTE(review):
        # detectron2's get_event_storage() expects an active EventStorage
        # context (as during training) — confirm for other callers.
        storage = get_event_storage()
        for batch in self.loader:
            num_inst_per_dataset = {}
            for data in batch:
                dataset_name = data["dataset"]
                num_inst_per_dataset[dataset_name] = num_inst_per_dataset.get(
                    dataset_name, 0
                ) + len(data["instances"])
            for dataset_name, num_inst in num_inst_per_dataset.items():
                storage.put_scalar(f"batch/{dataset_name}", num_inst)
            yield batch
|
|
|
|
|
class SampleCountMetricPrinter(EventWriter):
    """
    Event writer that logs the recent average of every ``batch/*`` scalar in
    the event storage (the per-dataset instance counts recorded by
    ``SampleCountingLoader``).

    Args:
        window_size (int): window passed to ``HistoryBuffer.avg`` for each
            scalar; presumably the number of most recent values averaged.
            Defaults to 20, the value previously hard-coded.
    """

    def __init__(self, window_size: int = 20):
        self.logger = logging.getLogger(__name__)
        self.window_size = window_size

    def write(self):
        storage = get_event_storage()
        batch_stats_strs = [
            f"{key} {buf.avg(self.window_size)}"
            for key, buf in storage.histories().items()
            if key.startswith("batch/")
        ]
        # When no per-dataset counts were recorded (e.g. training without
        # inference-based loaders), skip logging instead of emitting an empty
        # line on every write.
        if not batch_stats_strs:
            return
        self.logger.info(", ".join(batch_stats_strs))
|
|
|
|
|
class Trainer(DefaultTrainer):
    """
    ``DefaultTrainer`` specialized for DensePose.

    It customizes the evaluators (COCO + DensePose), the optimizer parameter
    groups (separate learning rates for CSE "features" / "embeddings"), the
    train / test data loaders (optionally combined with inference-based
    loaders) and test-time-augmented evaluation.
    """

    @classmethod
    def extract_embedder_from_model(cls, model: nn.Module) -> Optional[Embedder]:
        """
        Return ``model.roi_heads.embedder`` if the model has one, else ``None``.

        A ``DistributedDataParallel`` wrapper is unwrapped first so that the
        attribute lookup happens on the underlying model.
        """
        if isinstance(model, nn.parallel.DistributedDataParallel):
            model = model.module
        if hasattr(model, "roi_heads") and hasattr(model.roi_heads, "embedder"):
            return model.roi_heads.embedder
        return None

    @classmethod
    def test(
        cls,
        cfg: CfgNode,
        model: nn.Module,
        evaluators: Optional[Union[DatasetEvaluator, List[DatasetEvaluator]]] = None,
    ):
        """
        Args:
            cfg (CfgNode):
            model (nn.Module):
            evaluators (DatasetEvaluator, list[DatasetEvaluator] or None): if None, will call
                :meth:`build_evaluator` (passing the embedder extracted from ``model``).
                Otherwise, must have the same length as ``cfg.DATASETS.TEST``.

        Returns:
            dict: a dict of result metrics keyed by dataset name; when there is
            exactly one test dataset, that dataset's metrics dict is returned
            directly.
        """
        logger = logging.getLogger(__name__)
        if isinstance(evaluators, DatasetEvaluator):
            evaluators = [evaluators]
        if evaluators is not None:
            assert len(cfg.DATASETS.TEST) == len(evaluators), "{} != {}".format(
                len(cfg.DATASETS.TEST), len(evaluators)
            )
            embedder = None
        else:
            # Only needed when evaluators are built here; the embedder is the
            # same for every dataset, so extract it once.
            embedder = cls.extract_embedder_from_model(model)

        results = OrderedDict()
        for idx, dataset_name in enumerate(cfg.DATASETS.TEST):
            data_loader = cls.build_test_loader(cfg, dataset_name)
            if evaluators is not None:
                evaluator = evaluators[idx]
            else:
                try:
                    evaluator = cls.build_evaluator(cfg, dataset_name, embedder=embedder)
                except NotImplementedError:
                    logger.warning(
                        "No evaluator found. Use `DefaultTrainer.test(evaluators=)`, "
                        "or implement its `build_evaluator` method."
                    )
                    results[dataset_name] = {}
                    continue
            # Without distributed inference, only the main process runs the
            # evaluation; the other ranks record an empty result.
            if cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE or comm.is_main_process():
                results_i = inference_on_dataset(model, data_loader, evaluator)
            else:
                results_i = {}
            results[dataset_name] = results_i
            if comm.is_main_process():
                assert isinstance(
                    results_i, dict
                ), "Evaluator must return a dict on the main process. Got {} instead.".format(
                    results_i
                )
                logger.info("Evaluation results for {} in csv format:".format(dataset_name))
                print_csv_format(results_i)

        if len(results) == 1:
            results = list(results.values())[0]
        return results

    @classmethod
    def build_evaluator(
        cls,
        cfg: CfgNode,
        dataset_name: str,
        output_folder: Optional[str] = None,
        embedder: Optional[Embedder] = None,
    ) -> DatasetEvaluators:
        """
        Build the evaluators for ``dataset_name``.

        Always includes a ``Detectron2COCOEvaluatorAdapter``. A
        ``DensePoseCOCOEvaluator`` (configured from ``cfg.DENSEPOSE_EVALUATION``
        and given ``embedder``) is added when ``cfg.MODEL.DENSEPOSE_ON`` is set.

        Args:
            cfg (CfgNode): configuration.
            dataset_name (str): name of the dataset to evaluate on.
            output_folder (str or None): output directory; defaults to
                ``<cfg.OUTPUT_DIR>/inference``.
            embedder (Embedder or None): forwarded to the DensePose evaluator.

        Returns:
            DatasetEvaluators: the combined evaluators.
        """
        if output_folder is None:
            output_folder = os.path.join(cfg.OUTPUT_DIR, "inference")
        evaluators = []
        distributed = cfg.DENSEPOSE_EVALUATION.DISTRIBUTED_INFERENCE
        evaluators.append(
            Detectron2COCOEvaluatorAdapter(
                dataset_name, output_dir=output_folder, distributed=distributed
            )
        )
        if cfg.MODEL.DENSEPOSE_ON:
            storage = build_densepose_evaluator_storage(cfg, output_folder)
            evaluators.append(
                DensePoseCOCOEvaluator(
                    dataset_name,
                    distributed,
                    output_folder,
                    evaluator_type=cfg.DENSEPOSE_EVALUATION.TYPE,
                    min_iou_threshold=cfg.DENSEPOSE_EVALUATION.MIN_IOU_THRESHOLD,
                    storage=storage,
                    embedder=embedder,
                    should_evaluate_mesh_alignment=cfg.DENSEPOSE_EVALUATION.EVALUATE_MESH_ALIGNMENT,
                    mesh_alignment_mesh_names=cfg.DENSEPOSE_EVALUATION.MESH_ALIGNMENT_MESH_NAMES,
                )
            )
        return DatasetEvaluators(evaluators)

    @classmethod
    def build_optimizer(cls, cfg: CfgNode, model: nn.Module):
        """
        Build an SGD optimizer (wrapped by ``maybe_add_gradient_clipping``).

        Parameters matched by the ``"features"`` and ``"embeddings"`` overrides
        get their learning rate scaled by the CSE ``FEATURES_LR_FACTOR`` and
        ``EMBEDDING_LR_FACTOR`` respectively.
        """
        params = get_default_optimizer_params(
            model,
            base_lr=cfg.SOLVER.BASE_LR,
            weight_decay_norm=cfg.SOLVER.WEIGHT_DECAY_NORM,
            bias_lr_factor=cfg.SOLVER.BIAS_LR_FACTOR,
            weight_decay_bias=cfg.SOLVER.WEIGHT_DECAY_BIAS,
            overrides={
                "features": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.FEATURES_LR_FACTOR,
                },
                "embeddings": {
                    "lr": cfg.SOLVER.BASE_LR * cfg.MODEL.ROI_DENSEPOSE_HEAD.CSE.EMBEDDING_LR_FACTOR,
                },
            },
        )
        optimizer = torch.optim.SGD(
            params,
            cfg.SOLVER.BASE_LR,
            momentum=cfg.SOLVER.MOMENTUM,
            nesterov=cfg.SOLVER.NESTEROV,
            weight_decay=cfg.SOLVER.WEIGHT_DECAY,
        )
        return maybe_add_gradient_clipping(cfg, optimizer)

    @classmethod
    def build_test_loader(cls, cfg: CfgNode, dataset_name):
        """Build the test loader for ``dataset_name`` using a test-mode ``DatasetMapper``."""
        return build_detection_test_loader(cfg, dataset_name, mapper=DatasetMapper(cfg, False))

    @classmethod
    def build_train_loader(cls, cfg: CfgNode):
        """
        Build the training loader.

        If the config defines no inference-based loaders, the plain detection
        train loader is returned. Otherwise, a bootstrap model is built, moved
        to ``cfg.BOOTSTRAP_MODEL.DEVICE`` and loaded from
        ``cfg.BOOTSTRAP_MODEL.WEIGHTS``. It is used to create the
        inference-based loaders, which are combined with the plain loader
        (ratio 1.0 for the latter). The result is wrapped in a
        ``SampleCountingLoader``.
        """
        data_loader = build_detection_train_loader(cfg, mapper=DatasetMapper(cfg, True))
        if not has_inference_based_loaders(cfg):
            return data_loader
        model = cls.build_model(cfg)
        model.to(cfg.BOOTSTRAP_MODEL.DEVICE)
        DetectionCheckpointer(model).resume_or_load(cfg.BOOTSTRAP_MODEL.WEIGHTS, resume=False)
        inference_based_loaders, ratios = build_inference_based_loaders(cfg, model)
        loaders = [data_loader] + inference_based_loaders
        ratios = [1.0] + ratios
        combined_data_loader = build_combined_loader(cfg, loaders, ratios)
        sample_counting_loader = SampleCountingLoader(combined_data_loader)
        return sample_counting_loader

    def build_writers(self):
        """Default writers plus a ``SampleCountMetricPrinter`` for per-dataset batch counts."""
        writers = super().build_writers()
        writers.append(SampleCountMetricPrinter())
        return writers

    @classmethod
    def test_with_TTA(cls, cfg: CfgNode, model):
        """
        Evaluate ``model`` with test-time augmentation.

        Results are written under ``<cfg.OUTPUT_DIR>/inference_TTA``. Every
        key of the dict returned by :meth:`test` gets a ``"_TTA"`` suffix.
        """
        logger = logging.getLogger("detectron2.trainer")
        logger.info("Running inference with test-time augmentation ...")
        # `test` only extracts the embedder when it builds the evaluators
        # itself, so it must be extracted here and handed to build_evaluator.
        # Do it before wrapping: the TTA wrapper presumably does not expose
        # `roi_heads` directly.
        embedder = cls.extract_embedder_from_model(model)
        transform_data = load_from_cfg(cfg)
        model = DensePoseGeneralizedRCNNWithTTA(
            cfg, model, transform_data, DensePoseDatasetMapperTTA(cfg)
        )
        evaluators = [
            cls.build_evaluator(
                cfg,
                name,
                output_folder=os.path.join(cfg.OUTPUT_DIR, "inference_TTA"),
                embedder=embedder,
            )
            for name in cfg.DATASETS.TEST
        ]
        res = cls.test(cfg, model, evaluators)
        res = OrderedDict({k + "_TTA": v for k, v in res.items()})
        return res
|
|