# Copyright (c) OpenMMLab. All rights reserved.
from typing import Sequence

import torch
from mmengine.dataset import COLLATE_FUNCTIONS


@COLLATE_FUNCTIONS.register_module()
def yolow_collate(data_batch: Sequence, use_ms_training: bool = False) -> dict:
    """Collate a list of samples into a single batch for faster training.

    All ground-truth boxes in the batch are packed into one
    ``bboxes_labels`` tensor. Each row is built by concatenating the sample's
    position in the batch, its label and its box coordinates.

    Args:
        data_batch (Sequence): Batch of data. Each item is indexed with
            ``"inputs"`` and ``"data_samples"``.
        use_ms_training (bool): Whether to use multi-scale training. When
            True the inputs are returned as a plain list instead of being
            stacked into one tensor.

    Returns:
        dict: Has an ``"inputs"`` entry and a ``"data_samples"`` dict.
        ``"data_samples"`` always contains ``"bboxes_labels"``. It also
        contains ``"masks"`` when any sample has masks, and ``"texts"`` /
        ``"is_detection"`` when the first sample has those attributes.
    """
    images = []
    packed_rows = []
    mask_chunks = []

    for sample_idx, sample in enumerate(data_batch):
        data_sample = sample["data_samples"]
        images.append(sample["inputs"])

        boxes = data_sample.gt_instances.bboxes.tensor
        labels = data_sample.gt_instances.labels

        # NOTE(review): masks are only collected from samples that carry
        # them, while a bboxes_labels chunk is added for every sample. If a
        # batch ever mixes samples with and without masks, the rows of
        # "masks" will not line up with "bboxes_labels". Confirm that mixed
        # batches cannot happen.
        if "masks" in data_sample.gt_instances:
            mask_chunks.append(
                data_sample.gt_instances.masks.to(
                    dtype=torch.bool, device=boxes.device
                )
            )

        # First column: which sample in the batch the box belongs to.
        owner = labels.new_full((len(labels), 1), sample_idx)
        # Labels are presumably integer and boxes floating point.
        # torch.cat's type promotion decides the packed dtype.
        packed_rows.append(torch.cat((owner, labels[:, None], boxes), dim=1))

    packed_samples = {"bboxes_labels": torch.cat(packed_rows, 0)}
    collated_results = {"data_samples": packed_samples}
    if mask_chunks:
        packed_samples["masks"] = torch.cat(mask_chunks, 0)

    # Multi-scale training keeps per-image tensors, which may differ in size.
    collated_results["inputs"] = (
        images if use_ms_training else torch.stack(images, 0)
    )

    # Optional per-sample metadata. Only the first sample is probed for it.
    first_sample = data_batch[0]["data_samples"]
    if hasattr(first_sample, "texts"):
        packed_samples["texts"] = [
            item["data_samples"].texts for item in data_batch
        ]
    if hasattr(first_sample, "is_detection"):
        # detection flag
        detection_flags = [
            item["data_samples"].is_detection for item in data_batch
        ]
        packed_samples["is_detection"] = torch.tensor(detection_flags)

    return collated_results